diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index bfbbe60cc..ee4186650 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -3,10 +3,20 @@ name: ci on: push: branches: - - master + - main + paths: + - 'src/**' + - 'test/**' + - '.github/workflows/*.yml' + - 'pom.xml' pull_request: branches: - - master + - main + paths: + - 'src/**' + - 'test/**' + - '.github/workflows/*.yml' + - 'pom.xml' jobs: misc: @@ -14,38 +24,46 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 + - uses: actions/checkout@v4 - name: Set up JDK 11 - uses: actions/setup-java@v1 + uses: actions/setup-java@v4 with: - java-version: 11 + distribution: 'temurin' + java-version: '11' + cache: 'maven' - name: Verify run: mvn -B verify -DskipTests=true - name: Misc Tests - run: mvn -B '-Dtest=!sqlancer.dbms.**' test + run: mvn -Djacoco.skip=true -B '-Dtest=!sqlancer.dbms.**,!sqlancer.qpg.**' test + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + - name: Naming Convention Tests + run: python src/check_names.py citus: name: DBMS Tests (Citus) runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 + - uses: actions/checkout@v4 - name: Set up JDK 11 - uses: actions/setup-java@v1 + uses: actions/setup-java@v4 with: - java-version: 11 + distribution: 'temurin' + java-version: '11' + cache: 'maven' - name: Build SQLancer run: mvn -B package -DskipTests=true - name: Set up Citus run: | echo "deb http://apt.postgresql.org/pub/repos/apt/ `lsb_release -cs`-pgdg main" | sudo tee /etc/apt/sources.list.d/pgdg.list curl https://install.citusdata.com/community/deb.sh | sudo bash - sudo apt-get -y install postgresql-13-citus-10.1 + sudo sed -i 's/noble/jammy/g' /etc/apt/sources.list.d/citusdata_community.list # https://github.com/citusdata/citus/issues/7692 + sudo apt-get update + sudo apt-get -y install postgresql-17-citus-13.0 sudo chown -R $USER:$USER /var/run/postgresql - export PATH=/usr/lib/postgresql/13/bin:$PATH + export PATH=/usr/lib/postgresql/17/bin:$PATH cd ~ mkdir -p citus/coordinator citus/worker1 citus/worker2 initdb -D citus/coordinator @@ -70,281 +88,628 @@ jobs: psql -c "SELECT * from citus_add_node('localhost', 9701);" -p 9700 -U $USER -d test psql -c "SELECT * from citus_add_node('localhost', 9702);" -p 9700 -U $USER -d test - name: Run Tests - run: CITUS_AVAILABLE=true mvn -Dtest=TestCitus test + run: CITUS_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestCitus test + + cnosdb: + name: DBMS Tests (CnosDB, creation only) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up JDK 11 + uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '11' + cache: 'maven' + - name: Build SQLancer + run: mvn -B package -DskipTests=true + - name: Set up CnosDB + run: | + docker pull cnosdb/cnosdb:community-latest + docker run --name cnosdb -p 8902:8902 -d cnosdb/cnosdb:community-latest + until nc -z 127.0.0.1 8902 2>/dev/null; do sleep 1; done + - name: Run Tests + run: | + CNOSDB_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestCnosDBNoREC test + sleep 20 + CNOSDB_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestCnosDBTLP test clickhouse: name: DBMS Tests (ClickHouse) runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 + - uses: actions/checkout@v4 - name: Set up JDK 11 - uses: actions/setup-java@v1 + uses: actions/setup-java@v4 with: - java-version: 11 + distribution: 'temurin' + java-version: '11' + cache: 'maven' - name: Build SQLancer run: mvn -B package -DskipTests=true - name: Set up ClickHouse run: | - docker pull yandex/clickhouse-server:latest - docker run --ulimit nofile=262144:262144 --name clickhouse-server -p8123:8123 -d yandex/clickhouse-server:latest - sleep 5 + docker pull clickhouse/clickhouse-server:24.3.1.2672 + docker run --ulimit nofile=262144:262144 --name clickhouse-server -p8123:8123 -d clickhouse/clickhouse-server:24.3.1.2672 + until curl -sf http://127.0.0.1:8123/ping 2>/dev/null; do sleep 1; done - name: Run Tests - run: CLICKHOUSE_AVAILABLE=true mvn -Dtest=ClickHouseBinaryComparisonOperationTest test - + run: CLICKHOUSE_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=ClickHouseBinaryComparisonOperationTest,TestClickHouse,ClickHouseOperatorsVisitorTest,ClickHouseToStringVisitorTest test + - name: Show fatal errors + run: docker exec clickhouse-server grep Fatal /var/log/clickhouse-server/clickhouse-server.log || echo No Fatal Errors found + - name: Teardown ClickHouse server + run: | + docker stop clickhouse-server + docker rm clickhouse-server cockroachdb: name: DBMS Tests (CockroachDB) runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 + - name: Set up JDK 11 + uses: actions/setup-java@v4 with: - fetch-depth: 0 + distribution: 'temurin' + java-version: '11' + cache: 'maven' + - name: Build SQLancer + run: mvn -B package -DskipTests=true + - name: Set up CockroachDB + run: | + wget -qO- https://binaries.cockroachdb.com/cockroach-v24.2.0.linux-amd64.tgz | tar xvz + cd cockroach-v24.2.0.linux-amd64/ && ./cockroach start-single-node --insecure & + until cockroach-v24.2.0.linux-amd64/cockroach sql --insecure -e "SELECT 1" 2>/dev/null; do sleep 2; done + - name: Create SQLancer user + run: cd cockroach-v24.2.0.linux-amd64/ && ./cockroach sql --insecure -e "CREATE USER sqlancer; GRANT admin to sqlancer" && cd .. + - name: Run Tests + run: | + COCKROACHDB_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestCockroachDBNoREC test + COCKROACHDB_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestCockroachDBTLP test + COCKROACHDB_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestCockroachDBCERT test + + cockroachdb-qpg: + name: QPG Tests (CockroachDB) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 - name: Set up JDK 11 - uses: actions/setup-java@v1 + uses: actions/setup-java@v4 with: - java-version: 11 + distribution: 'temurin' + java-version: '11' + cache: 'maven' - name: Build SQLancer run: mvn -B package -DskipTests=true - name: Set up CockroachDB run: | - wget -qO- https://binaries.cockroachdb.com/cockroach-v21.1.7.linux-amd64.tgz | tar xvz - cd cockroach-v21.1.7.linux-amd64/ && ./cockroach start-single-node --insecure & - sleep 10 + wget -qO- https://binaries.cockroachdb.com/cockroach-v24.2.0.linux-amd64.tgz | tar xvz + cd cockroach-v24.2.0.linux-amd64/ && ./cockroach start-single-node --insecure & + until cockroach-v24.2.0.linux-amd64/cockroach sql --insecure -e "SELECT 1" 2>/dev/null; do sleep 2; done - name: Create SQLancer user - run: cd cockroach-v21.1.7.linux-amd64/ && ./cockroach sql --insecure -e "CREATE USER sqlancer; GRANT admin to sqlancer" && cd .. + run: cd cockroach-v24.2.0.linux-amd64/ && ./cockroach sql --insecure -e "CREATE USER sqlancer; GRANT admin to sqlancer" && cd .. - name: Run Tests - run: COCKROACHDB_AVAILABLE=true mvn -Dtest=TestCockroachDB test + run: COCKROACHDB_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestCockroachDBQPG test databend: name: DBMS Tests (Databend) runs-on: ubuntu-latest + services: + databend: + image: datafuselabs/databend:v1.2.896-nightly + env: + QUERY_DEFAULT_USER: sqlancer + QUERY_DEFAULT_PASSWORD: sqlancer + ports: + - 8000:8000 + - 3307:3307 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 + - name: Set up JDK 11 + uses: actions/setup-java@v4 with: - fetch-depth: 0 + distribution: 'temurin' + java-version: '11' + cache: 'maven' + - name: Build SQLancer + run: mvn -B package -DskipTests=true + - name: Run Tests + run: | + DATABEND_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestDatabendTLP test + DATABEND_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestDatabendNoREC test + DATABEND_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestDatabendPQS test + + datafusion: + name: DBMS Tests (DataFusion) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Rust + uses: actions-rs/toolchain@v1 + with: + toolchain: stable + override: true + - name: Build DataFusion Server + run: | + cd src/sqlancer/datafusion/server/datafusion_server + cargo build + - name: Start DataFusion Server + run: | + cd src/sqlancer/datafusion/server/datafusion_server + cargo run & - name: Set up JDK 11 - uses: actions/setup-java@v1 + uses: actions/setup-java@v4 with: - java-version: 11 + distribution: 'temurin' + java-version: '11' + cache: 'maven' - name: Build SQLancer run: mvn -B package -DskipTests=true - - name: Set up Databend + - name: Wait for DataFusion Server run: | - sudo apt update - sudo apt install mysql-client - LASTEST_TAG=$(curl -s GET https://api.github.com/repos/datafuselabs/databend/tags\?per_page\=1 | jq -r '.[].name') - curl -LJO https://github.com/datafuselabs/databend/releases/download/${LASTEST_TAG}/databend-${LASTEST_TAG}-x86_64-unknown-linux-musl.tar.gz - mkdir ./databend && tar xzvf databend-${LASTEST_TAG}-x86_64-unknown-linux-musl.tar.gz -C ./databend - ./databend/bin/databend-query & - - name: Create SQLancer user - run: mysql -uroot -h127.0.0.1 -P3307 -e "CREATE USER 'sqlancer' IDENTIFIED BY 'sqlancer'; GRANT ALL ON *.* TO sqlancer;" + for i in $(seq 1 30); do + if nc -z 127.0.0.1 50051 2>/dev/null; then + echo "DataFusion server is ready" + exit 0 + fi + echo "Waiting for DataFusion server... ($i/30)" + sleep 10 + done + echo "DataFusion server failed to start within 300s" + exit 1 - name: Run Tests run: | - DATABEND_AVAILABLE=true mvn -Dtest=TestDatabend test + DATAFUSION_AVAILABLE=true mvn -Djacoco.skip=true test -Pdatafusion-tests duckdb: name: DBMS Tests (DuckDB) runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 + - uses: actions/checkout@v4 - name: Set up JDK 11 - uses: actions/setup-java@v1 + uses: actions/setup-java@v4 with: - java-version: 11 + distribution: 'temurin' + java-version: '11' + cache: 'maven' - name: Build run: mvn -B package -DskipTests=true - name: DuckDB Tests - run: mvn -Dtest=TestDuckDB test + run: | + mvn -Djacoco.skip=true -Dtest=TestDuckDBTLP test + mvn -Djacoco.skip=true -Dtest=TestDuckDBNoREC test h2: name: DBMS Tests (H2) runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 + - name: Set up JDK 11 + uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '11' + cache: 'maven' + - name: Build SQLancer + run: mvn -B package -DskipTests=true + - name: Run Tests + run: mvn -Djacoco.skip=true -Dtest=TestH2 test + + hive: + name: DBMS Tests (Hive) + runs-on: ubuntu-latest + services: + metastore: + image: apache/hive:4.0.1 + env: + SERVICE_NAME: 'metastore' + ports: + - 9083:9083 + volumes: + - warehouse:/opt/hive/data/warehouse + hiveserver2: + image: apache/hive:4.0.1 + env: + SERVICE_NAME: 'hiveserver2' + ports: + - 10000:10000 + - 10002:10002 + volumes: + - warehouse:/opt/hive/data/warehouse + steps: + - uses: actions/checkout@v4 + - name: Set up JDK 11 + uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '11' + cache: 'maven' + - name: Build SQLancer + run: mvn -B package -DskipTests=true + - name: Run Tests + run: HIVE_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestHiveTLP test + + spark: + name: DBMS Tests (Spark) + runs-on: ubuntu-latest + + services: + spark: + image: apache/spark:3.5.1 + ports: + - 10000:10000 + + command: >- + /opt/spark/bin/spark-submit + --class org.apache.spark.sql.hive.thriftserver.HiveThriftServer2 + --name "Thrift JDBC/ODBC Server" + --master local[*] + --driver-memory 4g + --conf spark.hive.server2.thrift.port=10000 + --conf spark.sql.warehouse.dir=/tmp/spark-warehouse + spark-internal + + steps: + - uses: actions/checkout@v3 with: fetch-depth: 0 + - name: Set up JDK 11 - uses: actions/setup-java@v1 + uses: actions/setup-java@v3 with: - java-version: 11 + distribution: 'temurin' + java-version: '11' + cache: 'maven' + - name: Build SQLancer run: mvn -B package -DskipTests=true + - name: Run Tests - run: H2_AVAILABLE=true mvn -Dtest=TestH2 test + run: SPARK_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestSparkTLP test + + hsqldb: + name: DBMS Tests (HSQLDB) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up JDK 11 + uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '11' + cache: 'maven' + - name: Build SQLancer + run: mvn -B package -DskipTests=true + - name: Run Tests + run: | + mvn -Djacoco.skip=true -Dtest=TestHSQLDBNoREC test + mvn -Djacoco.skip=true -Dtest=TestHSQLDBTLP test mariadb: name: DBMS Tests (MariaDB) - runs-on: ubuntu-18.04 + runs-on: ubuntu-latest + services: + mysql: + image: mariadb:11.7.2 + env: + MYSQL_ROOT_PASSWORD: root + ports: + - 3306:3306 + options: --health-cmd="healthcheck.sh --connect --innodb_initialized" --health-interval=10s --health-timeout=5s --health-retries=10 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 + - name: Set up JDK 11 + uses: actions/setup-java@v4 with: - fetch-depth: 0 + distribution: 'temurin' + java-version: '11' + cache: 'maven' + - name: Build SQLancer + run: mvn -B package -DskipTests=true + - name: Create SQLancer User + run: sudo mysql -h 127.0.0.1 -uroot -proot -e "CREATE USER 'sqlancer'@'%' IDENTIFIED BY 'sqlancer'; GRANT ALL PRIVILEGES ON * . * TO 'sqlancer'@'%';" + - name: Run Tests + run: MARIADB_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestMariaDB test + + materialize: + name: DBMS Tests (Materialize) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Materialize + run: | + docker pull materialize/materialized:latest + docker run -d -p6875:6875 -p6877:6877 materialize/materialized:latest + until pg_isready -h localhost -p 6875 -U materialize; do sleep 1; done - name: Set up JDK 11 - uses: actions/setup-java@v1 + uses: actions/setup-java@v4 with: - java-version: 11 + distribution: 'temurin' + java-version: '11' + cache: 'maven' - name: Build SQLancer run: mvn -B package -DskipTests=true - - name: Install MariaDB + - name: Run Tests run: | - sudo apt-key adv --recv-keys --keyserver hkp://keyserver.ubuntu.com:80 0xF1656F24C74CD1D8 - sudo add-apt-repository 'deb [arch=amd64,arm64,ppc64el] http://sfo1.mirrors.digitalocean.com/mariadb/repo/10.3/ubuntu bionic main' - sudo apt update - sudo apt install mariadb-server - sudo systemctl start mariadb - - name: Create SQLancer User - run: sudo mysql -uroot -proot -e "CREATE USER 'sqlancer'@'localhost' IDENTIFIED BY 'sqlancer'; GRANT ALL PRIVILEGES ON * . * TO 'sqlancer'@'localhost';" + MATERIALIZE_AVAILABLE=true mvn -Djacoco.skip=true test -Dtest=TestMaterializeNoREC + MATERIALIZE_AVAILABLE=true mvn -Djacoco.skip=true test -Dtest=TestMaterializeTLP + MATERIALIZE_AVAILABLE=true mvn -Djacoco.skip=true test -Dtest=TestMaterializePQS + + materialize-qpg: + name: QPG Tests (Materialize) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Materialize + run: | + docker pull materialize/materialized:latest + docker run -d -p6875:6875 -p6877:6877 materialize/materialized:latest + until pg_isready -h localhost -p 6875 -U materialize; do sleep 1; done + - name: Set up JDK 11 + uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '11' + cache: 'maven' + - name: Build SQLancer + run: mvn -B package -DskipTests=true - name: Run Tests - run: MARIADB_AVAILABLE=true mvn -Dtest=TestMariaDB test + run: | + MATERIALIZE_AVAILABLE=true mvn -Djacoco.skip=true test -Dtest=TestMaterializeQPG + MATERIALIZE_AVAILABLE=true mvn -Djacoco.skip=true test -Dtest=TestMaterializeQueryPlan mysql: - name: DBMS Tests (MySQL) - runs-on: ubuntu-18.04 + name: DBMS Tests (MySQL, CERT creation only) + runs-on: ubuntu-latest + services: + mysql: + image: mysql:8.4 + env: + MYSQL_ROOT_PASSWORD: root + ports: + - 3306:3306 + options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=10 steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 + - uses: actions/checkout@v4 - name: Set up JDK 11 - uses: actions/setup-java@v1 + uses: actions/setup-java@v4 with: - java-version: 11 + distribution: 'temurin' + java-version: '11' + cache: 'maven' - name: Build SQLancer run: mvn -B package -DskipTests=true - - name: Set up MySQL - run: | - sudo apt-get install libssl-dev libmecab2 libjson-perl mecab-ipadic-utf8 - sudo apt-get remove mysql-* - wget -q https://dev.mysql.com/get/Downloads/MySQL-8.0/mysql-server_8.0.20-1ubuntu18.04_amd64.deb-bundle.tar - tar -xvf mysql-server_8.0.20-1ubuntu18.04_amd64.deb-bundle.tar - sudo dpkg -i *.deb - sudo systemctl start mysql - name: Create SQLancer user - run: mysql -uroot -proot -e "CREATE USER 'sqlancer'@'localhost' IDENTIFIED BY 'sqlancer'; GRANT ALL PRIVILEGES ON * . * TO 'sqlancer'@'localhost';" + run: mysql -h 127.0.0.1 -uroot -proot -e "CREATE USER 'sqlancer'@'%' IDENTIFIED BY 'sqlancer'; GRANT ALL PRIVILEGES ON * . * TO 'sqlancer'@'%';" - name: Run Tests run: | - MYSQL_AVAILABLE=true mvn test -Dtest=TestMySQLPQS - MYSQL_AVAILABLE=true mvn test -Dtest=TestMySQLTLP + MYSQL_AVAILABLE=true mvn -Djacoco.skip=true test -Dtest=TestMySQLPQS + MYSQL_AVAILABLE=true mvn -Djacoco.skip=true test -Dtest=TestMySQLTLP + MYSQL_AVAILABLE=true mvn -Djacoco.skip=true test -Dtest=TestMySQLCERT + MYSQL_AVAILABLE=true mvn -Djacoco.skip=true test -Dtest=TestMySQLDQE + oceanbase: + name: DBMS Tests (OceanBase) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up JDK 11 + uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '11' + cache: 'maven' + - name: Build SQLancer + run: mvn -B package -DskipTests=true + - name: Set up OceanBase + run: | + docker run -p 2881:2881 --name oceanbase-ce -e MODE=mini -d oceanbase/oceanbase-ce:4.2.1-lts + until mysql -h127.1 -uroot@test -P2881 --connect-timeout=3 -Doceanbase -A -e "SELECT 1" 2>/dev/null; do sleep 5; done + mysql -h127.1 -uroot@test -P2881 -Doceanbase -A -e"CREATE USER 'sqlancer'@'%' IDENTIFIED BY 'sqlancer'; GRANT ALL PRIVILEGES ON * . * TO 'sqlancer'@'%';" + - name: Run Tests + run: | + OCEANBASE_AVAILABLE=true mvn -Djacoco.skip=true test -Dtest=TestOceanBaseNoREC + OCEANBASE_AVAILABLE=true mvn -Djacoco.skip=true test -Dtest=TestOceanBasePQS + OCEANBASE_AVAILABLE=true mvn -Djacoco.skip=true test -Dtest=TestOceanBaseTLP postgres: name: DBMS Tests (PostgreSQL) runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 + - uses: actions/checkout@v4 - name: Set up PostgreSQL - uses: harmon758/postgresql-action@v1 + uses: harmon758/postgresql-action@v1.0.0 with: - postgresql version: '12' + postgresql version: '18' postgresql user: 'sqlancer' postgresql password: 'sqlancer' postgresql db: 'test' - name: Set up JDK 11 - uses: actions/setup-java@v1 + uses: actions/setup-java@v4 with: - java-version: 11 + distribution: 'temurin' + java-version: '11' + cache: 'maven' - name: Build SQLancer run: mvn -B package -DskipTests=true - name: Run Tests - run: POSTGRES_AVAILABLE=true mvn -Dtest=TestPostgres test + run: | + POSTGRES_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestPostgresPQS test + POSTGRES_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestPostgresTLP test + POSTGRES_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestPostgresNoREC test + POSTGRES_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestPostgresCERT test + presto: + name: DBMS Tests (Presto) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up JDK 11 + uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '11' + cache: 'maven' + - name: Set up Presto + run: | + docker pull prestodb/presto:latest + echo "connector.name=memory" >> memory.properties + docker run -p 8080:8080 -d -v ./memory.properties:/opt/presto-server/etc/catalog/memory.properties --name presto prestodb/presto:latest + until curl -sf http://127.0.0.1:8080/v1/info 2>/dev/null; do sleep 2; done + - name: Build SQLancer + run: mvn -B package -DskipTests=true + - name: Run Tests + run: | + PRESTO_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestPrestoNoREC test + docker restart presto && until curl -sf http://127.0.0.1:8080/v1/info 2>/dev/null; do sleep 2; done + PRESTO_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestPrestoTLP test sqlite: name: DBMS Tests (SQLite) runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 + - uses: actions/checkout@v4 - name: Set up JDK 11 - uses: actions/setup-java@v1 + uses: actions/setup-java@v4 with: - java-version: 11 + distribution: 'temurin' + java-version: '11' + cache: 'maven' - name: Build run: mvn -B package -DskipTests=true - name: SQLite Tests - run: | - mvn -Dtest=TestSQLitePQS test - mvn -Dtest=TestSQLite3 test + run: | + mvn -Djacoco.skip=true -Dtest=TestSQLitePQS test + mvn -Djacoco.skip=true -Dtest=TestSQLiteTLP test + mvn -Djacoco.skip=true -Dtest=TestSQLiteNoREC test + mvn -Djacoco.skip=true -Dtest=TestSQLiteCODDTest test + + sqlite-qpg: + name: QPG Tests (SQLite) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up JDK 11 + uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '11' + cache: 'maven' + - name: Build + run: mvn -B package -DskipTests=true + - name: SQLite Tests for QPG + run: | + mvn -Djacoco.skip=true -Dtest=TestSQLiteQPG test tidb: - name: DBMS Tests (TiDB) + name: DBMS Tests (TiDB, TLP creation only) runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 + - uses: actions/checkout@v4 - name: Set up JDK 11 - uses: actions/setup-java@v1 + uses: actions/setup-java@v4 with: - java-version: 11 + distribution: 'temurin' + java-version: '11' + cache: 'maven' - name: Build SQLancer run: mvn -B package -DskipTests=true - name: Set up TiDB run: | - docker pull pingcap/tidb:latest - docker run --name tidb-server -d -p 4000:4000 pingcap/tidb:latest - sleep 10 + docker pull hawkingrei/tidb-playground:nightly-2025-09-16 + docker run --name tidb-server -d -p 4000:4000 hawkingrei/tidb-playground:nightly-2025-09-16 + until mysql -h 127.0.0.1 -P 4000 -u root --connect-timeout=3 -e "SELECT 1" 2>/dev/null; do sleep 3; done - name: Create SQLancer user - run: sudo mysql -h 127.0.0.1 -P 4000 -u root -D test -e "CREATE USER 'sqlancer'@'%' IDENTIFIED WITH mysql_native_password BY 'sqlancer'; GRANT ALL PRIVILEGES ON *.* TO 'sqlancer'@'%' WITH GRANT OPTION; FLUSH PRIVILEGES;" + run: mysql -h 127.0.0.1 -P 4000 -u root -D test -e "CREATE USER 'sqlancer'@'%' IDENTIFIED WITH mysql_native_password BY 'sqlancer'; GRANT ALL PRIVILEGES ON *.* TO 'sqlancer'@'%' WITH GRANT OPTION; FLUSH PRIVILEGES;" - name: Run Tests - run: TIDB_AVAILABLE=true mvn -Dtest=TestTiDB test + run: | + TIDB_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestTiDBTLP test + TIDB_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestTiDBCERT test - java13: - name: Java 13 Compatibility (DuckDB) + tidb-qpg: + name: QPG Tests (TiDB) runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - name: Set up JDK 13 - uses: actions/setup-java@v1 + - uses: actions/checkout@v4 + - name: Set up JDK 11 + uses: actions/setup-java@v4 with: - java-version: 13 - - name: Build + distribution: 'temurin' + java-version: '11' + cache: 'maven' + - name: Build SQLancer run: mvn -B package -DskipTests=true - - name: Shortly run DuckDB - run: cd target && java -jar $(ls | grep -P 'sqlancer-[0-9.]*.jar') --num-threads 4 --timeout-seconds 30 --num-queries 0 duckdb + - name: Set up TiDB + run: | + docker pull hawkingrei/tidb-playground:nightly-2025-09-16 + docker run --name tidb-server -d -p 4000:4000 hawkingrei/tidb-playground:nightly-2025-09-16 + until mysql -h 127.0.0.1 -P 4000 -u root --connect-timeout=3 -e "SELECT 1" 2>/dev/null; do sleep 3; done + - name: Create SQLancer user + run: mysql -h 127.0.0.1 -P 4000 -u root -D test -e "CREATE USER 'sqlancer'@'%' IDENTIFIED WITH mysql_native_password BY 'sqlancer'; GRANT ALL PRIVILEGES ON *.* TO 'sqlancer'@'%' WITH GRANT OPTION; FLUSH PRIVILEGES;" + - name: Run Tests + run: TIDB_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestTiDBQPG test - java14: - name: Java 14 Compatibility (DuckDB) + yugabyte: + name: DBMS Tests (YugabyteDB) runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - name: Set up JDK 14 - uses: actions/setup-java@v1 + - uses: actions/checkout@v4 + - name: Set up JDK 11 + uses: actions/setup-java@v4 with: - java-version: 14 - - name: Build + distribution: 'temurin' + java-version: '11' + cache: 'maven' + - name: Build SQLancer run: mvn -B package -DskipTests=true - - name: Shortly run DuckDB - run: cd target && java -jar $(ls | grep -P 'sqlancer-[0-9.]*.jar') --num-threads 4 --timeout-seconds 30 --num-queries 0 duckdb - + - name: Set up Yugabyte + run: | + docker pull yugabytedb/yugabyte:latest + docker run -d --name yugabyte -p7000:7000 -p9000:9000 -p5433:5433 -p9042:9042 yugabytedb/yugabyte:latest bin/yugabyted start --daemon=false + until pg_isready -h localhost -p 5433; do sleep 1; done + - name: Run Tests + run: | + YUGABYTE_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestYSQLNoREC test + YUGABYTE_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestYSQLTLP test + YUGABYTE_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestYSQLPQS test + YUGABYTE_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestYCQL test - java15: - name: Java 15 EA Compatibility (DuckDB) + doris: + name: DBMS Tests (Apache Doris) runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - name: Set up JDK 15 - uses: actions/setup-java@v1 + - uses: actions/checkout@v4 + - name: Set up JDK 11 + uses: actions/setup-java@v4 with: - java-version: 15-ea - - name: Build + distribution: 'temurin' + java-version: '11' + cache: 'maven' + - name: install mysql client + run: | + sudo apt update + sudo apt install mysql-client --assume-yes + - name: Set up Apache Doris + run: | + sudo sysctl -w vm.max_map_count=2000000 + wget -q https://apache-doris-releases.oss-accelerate.aliyuncs.com/apache-doris-2.1.4-bin-x64.tar.gz + tar zxf apache-doris-2.1.4-bin-x64.tar.gz + mv apache-doris-2.1.4-bin-x64 apache-doris + sudo swapoff -a + cd apache-doris/fe + ./bin/start_fe.sh --daemon + cd ../be + ./bin/start_be.sh --daemon + + until mysql -u root -h 127.0.0.1 --port 9030 --connect-timeout=3 -e "SELECT 1" 2>/dev/null; do sleep 3; done + IP=$(hostname -I | awk '{print $1}') + mysql -u root -h 127.0.0.1 --port 9030 -e "ALTER SYSTEM ADD BACKEND '${IP}:9050';" + mysql -u root -h 127.0.0.1 --port 9030 -e "CREATE USER 'sqlancer' IDENTIFIED BY 'sqlancer'; GRANT ALL ON *.* TO sqlancer;" + - name: Build SQLancer run: mvn -B package -DskipTests=true - - name: Shortly run DuckDB - run: cd target && java -jar $(ls | grep -P 'sqlancer-[0-9.]*.jar') --num-threads 4 --timeout-seconds 30 --num-queries 0 duckdb + - name: Run Tests + run: | + DORIS_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestDorisNoREC test + DORIS_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestDorisPQS test + DORIS_AVAILABLE=true mvn -Djacoco.skip=true -Dtest=TestDorisTLP test diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 342a358f3..6194c1529 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,9 +9,10 @@ jobs: steps: - uses: actions/checkout@v2 - name: Set up Maven Central Repository - uses: actions/setup-java@v1 + uses: actions/setup-java@v3 with: - java-version: 11 + distribution: 'temurin' + java-version: '11' server-id: ossrh server-username: MAVEN_USERNAME server-password: MAVEN_PASSWORD @@ -29,9 +30,10 @@ jobs: - name: Check out the repo uses: actions/checkout@v2 - name: Set up JDK 11 - uses: actions/setup-java@v1 + uses: actions/setup-java@v3 with: - java-version: 11 + distribution: 'temurin' + java-version: '11' - name: Build SQLancer run: mvn -B package -DskipTests=true - name: Push to Docker Hub diff --git a/.gitignore b/.gitignore index 48efcf83a..d7cbeb55f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,7 @@ target/ .classpath -.settings/org.eclipse.core.resources.prefs -.settings/org.eclipse.m2e.core.prefs -.settings/org.eclipse.jdt.core.prefs +.settings/ +.vscode .project .checkstyle *.DS_Store @@ -11,3 +10,7 @@ SQLancer.iml dependency-reduced-pom.xml database0.db databaseconnectiontest.db +database*.log +database*.properties +database*.script +databases/ \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 81177e352..ea5baea1c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,6 +1,6 @@ # Development -## Working with Eclipse +## Working with Eclipse [[Video Guide]](https://www.youtube.com/watch?v=KsuGrOLKb9Q) Developing SQLancer using Eclipse is expected to work well. You can import SQLancer with a single step: @@ -12,7 +12,23 @@ If you do not find an option to import Maven projects, you might need to install ## Implementing Support for a New DBMS -The DuckDB implementation provides a good template for a new implementation. The `DuckDBProvider` class is the central class that manages the creation of the databases and executes the selected test oracles. Try to copy its structure for the new DBMS that you want to implement, and start by generate databases (without implementing a test oracle). As part of this, you will also need to implement the equivalent of `DuckDBSchema`, which represents the database schema of the generated database. After you can successfully generate databases, the next step is to generate one of the test oracles. For example, you might want to implement NoREC (see `DuckDBNoRECOracle` or `DuckDBQueryPartitioningWhereTester` for TLP). As part of this, you must also implement a random expression generator (see `DuckDBExpressionGenerator`) and a visitor to derive the textual representation of an expression (see `DuckDBToStringVisitor`). +The DuckDB implementation provides a good template for a new implementation. The `DuckDBProvider` class is the central class that manages the creation of the databases and executes the selected test oracles. Try to copy its structure for the new DBMS that you want to implement, and start by generate databases (without implementing a test oracle). As part of this, you will also need to implement the equivalent of `DuckDBSchema`, which represents the database schema of the generated database. After you can successfully generate databases, the next step is to generate one of the test oracles. For example, you might want to implement NoREC (see enum value `NOREC` in `DuckDBOracleFactory`). As part of this, you must also implement a random expression generator (see `DuckDBExpressionGenerator`) and a visitor to derive the textual representation of an expression (see `DuckDBToStringVisitor`). + +Please consider the following suggestions when creating a PR to contribute a new DBMS: +* Ensure that `mvn verify -DskipTests=true` does not result in style violations. +* Add a [CI test](https://github.com/sqlancer/sqlancer/blob/master/.github/workflows/main.yml) to ensure that future changes to SQLancer are unlikely to break the newly-supported DBMS. It is reasonable to do this in a follow-up PR—please indicate whether you plan to do so in the PR description. +* Add the DBMS' name to the [check_names.py](https://github.com/sqlancer/sqlancer/blob/master/src/check_names.py) script, which ensures adherence to a common prefix in the Java classes. +* Add the DBMS' name to the [README.md](https://github.com/sqlancer/sqlancer/blob/master/README.md#supported-dbms) file. +* It would be easier to review multiple smaller PRs, than one PR that contains the complete implementation. Consider contributing parts of your implementation as you work on their implementation. + +### Expected Errors + +Most statements have an [ExpectedError](https://github.com/sqlancer/sqlancer/blob/aa0c0eccba4eefa75bfd518f608c9222c692c11d/src/sqlancer/common/query/ExpectedErrors.java) object associated with them. This object essentially contains a list of errors, one of which the database system might return if it cannot successfully execute the statement. These errors are typically added through a trial-and-error process while considering various tradeoffs. For example, consider the [DuckDBInsertGenerator](https://github.com/sqlancer/sqlancer/blob/aa0c0eccba4eefa75bfd518f608c9222c692c11d/src/sqlancer/duckdb/gen/DuckDBInsertGenerator.java#L38) class, whose expected errors are specified in [DuckDBErrors](https://github.com/sqlancer/sqlancer/blob/aa0c0eccba4eefa75bfd518f608c9222c692c11d/src/sqlancer/duckdb/DuckDBErrors.java#L90). When implementing such a generator, the list of expected errors might first be empty. When running the generator for the first time, you might receive an error such as "create unique index, table contains duplicate data", indicating that creating the index failed due to duplicate data. In principle, this error could be avoided by first checking whether the column contains any duplicate values. However, checking this would be expensive and error-prone (e.g., consider string similarity, which might depend on collations); thus, the obvious choice would be to add this string to the list of expected errors, and run the generator again to check for any other expected errors. In other cases, errors might be best addressed through improvements in the generators. For example, it is typically straightforward to generate syntactically-valid statements, which is why syntax errors should not be ignored. This approach is effective in uncovering internal errors; rather than ignoring them as an expected error, report them, and see [Unfixed Bugs](#unfixed-bugs) below. + +### Bailing Out While Generating a Statement + +In some cases, it might be undesirable or even impossible to generate a specific statement type. For example, consider that SQLancer tries to execute a `DROP TABLE` statement (e.g., see [TiDBDropTableGenerator](https://github.com/sqlancer/sqlancer/blob/30948f34acc2354d6be18a70bdeeebff1e73fa48/src/sqlancer/tidb/gen/TiDBDropTableGenerator.java)), but the database contains only a single table. Dropping the table would result in all subsequent attempts to insert data or query it to fail. Thus, in such a case, it might be more efficient to "bail out" by abandoning the current attempt to generate the statement. This can be achieved by throwing a `IgnoreMeException`. Unlike for other exceptions, SQLancer silently continues execution rather than reporting this exception to the user. + ### Typed vs. Untyped Expression Generation @@ -30,6 +46,47 @@ For a permissive DBMS, implementing the expression generator is easier, since th For a strict DBMS, the better approach is typically to attempt to generate expressions of the expected type. For PostgreSQL, the expression generator thus expects an additional type argument (see [PostgreSQLExpressionGenerator](https://github.com/sqlancer/sqlancer/blob/86647df8aa2dd8d167b5c3ce3297290f5b0b2bcd/src/sqlancer/postgres/gen/PostgresExpressionGenerator.java#L251)). This type is propagated recursively. For example, if we require a predicate for the `WHERE` clause, we pass boolean as a type. The expression generator then calls a method `generateBooleanExpression` that attempts to produce a boolean expression, by, for example, generating a comparison (e.g., `<=`). For the comparison's operands, a random type is then selected and propagated. For example, if an integer type is selected, then `generateExpression` is called with this type once for the left operand, and once for the right operand. Note that this process does not guarantee that the expression will indeed have the expected type. It might happen, for example, that the expression generator attempts to produce an integer value, but that it produces a double value instead, namely when an integer overflow occurs, which, depending on the DBMS, implicitly converts the result to a floating-point value. +#### Supported DBMS + +Since SQL dialects differ widely, each DBMS to be tested requires a separate implementation. + +| DBMS | Status | Expression Generation | Description | +| ---------------------------- | ----------- | ---------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| SQLite | Working | Untyped | This implementation is currently affected by a significant performance regression that still needs to be investigated | +| MySQL | Working | Untyped | Running this implementation likely uncovers additional, unreported bugs. | +| PostgreSQL | Working | Typed | | +| Citus (PostgreSQL Extension) | Working | Typed | This implementation extends the PostgreSQL implementation of SQLancer, and was contributed by the Citus team. | +| MariaDB | Preliminary | Untyped | The implementation of this DBMS is very preliminary, since we stopped extending it after all but one of our bug reports were addressed. Running it likely uncovers additional, unreported bugs. | +| CockroachDB | Working | Typed | | +| TiDB | Working | Untyped | | +| DuckDB | Working | Untyped, Generic | | +| ClickHouse | Preliminary | Untyped, Generic | Implementing the different table engines was not convenient, which is why only a very preliminary implementation exists. | +| TDEngine | Removed | Untyped | We removed the TDEngine implementation since all but one of our bug reports were still unaddressed five months after we reported them. | +| OceanBase | Working | Untyped | | +| YugabyteDB | Working | Typed (YSQL), Untyped (YCQL) | YSQL implementation based on Postgres code. YCQL implementation is primitive for now and uses Cassandra JDBC driver as a proxy interface. | +| Databend | Working | Typed | | +| QuestDB | Working | Untyped, Generic | The implementation of QuestDB is still WIP, current version covers very basic data types, operations and SQL keywords. | +| CnosDB | Working | Typed | The implementation of CnosDB currently uses Restful API. | +| Materialize | Working | Typed | | +| Apache Doris | Preliminary | Typed | This is a preliminary implementation, which only contains the common logic of Doris. We have found some errors through it, and hope to improve it in the future. | +| Presto | Preliminary | Typed | This is a preliminary implementation, only basic types supported. | +| DataFusion | Preliminary | Typed | Only basic SQL features are supported. | + +#### Previously Supported DBMS + +Some DBMS were once supported but subsequently removed. + +| DBMS | Pull Request | Description | +| ---------- | ----------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- | +| ArangoDB | [#915](https://github.com/sqlancer/sqlancer/pull/915) | This implementation was removed because ArangoDB is a NoSQL DBMS, while the majority were SQL DBMSs, which resulted in difficulty refactoring SQLancer. | +| Cosmos | [#915](https://github.com/sqlancer/sqlancer/pull/915) | This implementation was removed because Cosmos is a NoSQL DBMS, while the majority were SQL DBMSs, which resulted in difficulty refactoring SQLancer. | +| MongoDB | [#915](https://github.com/sqlancer/sqlancer/pull/915) | This implementation was removed because MongoDB is a NoSQL DBMS, while the majority were SQL DBMSs, which resulted in difficulty refactoring SQLancer. | +| StoneDB | [#963](https://github.com/sqlancer/sqlancer/pull/963) | This implementation was removed because development of StoneDB stopped. + +### Unfixed Bugs + +Often, some bugs are fixed only after an extended period, meaning that SQLancer will repeatedly report the same bug. In such cases, it might be possible to avoid generating the problematic pattern, or adding an expected error with the internal error message. Rather than, for example, commenting out the code with the bug-inducing pattern, a pattern implemented by the [TiDBBugs class](https://github.com/sqlancer/sqlancer/blob/4c20a94b3ad2c037e1a66c0b637184f8c20faa7e/src/sqlancer/tidb/TiDBBugs.java) should be applied. The core idea is to use a public, static flag for each issue, which is set to true as long as the issue persists (e.g., see [bug35652](https://github.com/sqlancer/sqlancer/blob/4c20a94b3ad2c037e1a66c0b637184f8c20faa7e/src/sqlancer/tidb/TiDBBugs.java#L55)). The work-around code is then executed—or the problematic pattern should not be generated—if the flag is set to true (e.g., [an expected error is added for bug35652](https://github.com/sqlancer/sqlancer/blob/59564d818d991d54b32fa5a79c9f733799c090f2/src/sqlancer/tidb/TiDBErrors.java#L47)). This makes it easy to later on identify and remove all such work-around code once the issue has been fixed. + ## Options SQLancer uses [JCommander](https://jcommander.org/) for handling options. The `MainOptions` class contains options that are expected to be supported by all DBMS-testing implementations. Furthermore, each `*Provider` class provides a method to return an additional set of supported options. @@ -50,12 +107,12 @@ You can run them using the following command: mvn verify ``` -We use [Travis-CI](https://travis-ci.com/) to automatically check PRs. +We use [GitHub Actions](https://github.com/sqlancer/sqlancer/blob/master/.github/workflows/main.yml) to automatically check PRs. ## Testing -As part of the Travis-CI gate, we use smoke testing by running SQLancer on each supported DBMS for some minutes, to test that nothing is obviously broken. For DBMS for which all bugs have been fixed, we verify that SQLancer cannot find any further bugs (i.e., the return code is zero). +As part of the GitHub Actions check, we use smoke testing by running SQLancer on each supported DBMS for some minutes, to test that nothing is obviously broken. For DBMS for which all bugs have been fixed, we verify that SQLancer cannot find any further bugs (i.e., the return code is zero). In addition, we use [unit tests](https://github.com/sqlancer/sqlancer/tree/master/test/sqlancer) to test SQLancer's core functionality, such as random string and number generation as well as option passing. When fixing a bug, add a unit test, if it is easily possible. @@ -65,13 +122,13 @@ You can run the tests using the following command: mvn test ``` -Note that per default, the smoke testing is performed only for embedded DBMS (i.e., DuckDB and SQLite). To run smoke tests also for the other DBMS, you need to set environment variables. For example, you can run the MySQL smoke testing (and no other tests) using the following command: +Note that per default, the smoke testing is performed only for embedded DBMS (e.g., DuckDB and SQLite). To run smoke tests also for the other DBMS, you need to set environment variables. For example, you can run the MySQL smoke testing (and no other tests) using the following command: ``` MYSQL_AVAILABLE=true mvn -Dtest=TestMySQL test ``` -For up-to-date testing commands, check out the `.travis.yml` file. +For up-to-date testing commands, check out the `.github/workflows/main.yml` file. ## Reviewing @@ -88,4 +145,5 @@ Please pay attention to good commit messages (in particular subject lines). As b 2. Do not end the subject line with a period. For example, write "Refactor the handling of indexes" rather than "Refactor the handling of indexes.". 3. Use the imperative mood in the subject line. For example, write "Refactor the handling of indexes" rather than "Refactoring" or "Refactor**ed** the handling of indexes". -Please also pay attention to a clean commit history. Rather than merging with the main branch, use `git rebase` to rebase your commits on the main branch. Sometimes, it might happen that you discover an issue only after having already created a commit, for example, when an issue is found by `mvn verify` in the Travis CI. Do not introduce a separate commit for such issues. If the issue was introduced by the last commit, you can fix the issue, and use `git commit --amend` to change the latest commit. If the change was introduced by one of the previous commits, you can use `git rebase -i` to change the respective commit. If you already have a number of such commits, you can use `git squash` to "collapse" multiple commits into one. For more information, you might want to read [How (and Why!) to Keep Your Git Commit History Clean](https://about.gitlab.com/blog/2018/06/07/keeping-git-commit-history-clean/) written by Kushal Pandya. +Please also pay attention to a clean commit history. Rather than merging with the main branch, use `git rebase` to rebase your commits on the main branch. Sometimes, it might happen that you discover an issue only after having already created a commit, for example, when an issue is found by `mvn verify` in the CI checks. Do not introduce a separate commit for such issues. If the issue was introduced by the last commit, you can fix the issue, and use `git commit --amend` to change the latest commit. If the change was introduced by one of the previous commits, you can use `git rebase -i` to change the respective commit. If you already have a number of such commits, you can use `git squash` to "collapse" multiple commits into one. For more information, you might want to read [How (and Why!) to Keep Your Git Commit History Clean](https://about.gitlab.com/blog/2018/06/07/keeping-git-commit-history-clean/) written by Kushal Pandya. + diff --git a/README.md b/README.md index 2209e7589..134f47666 100644 --- a/README.md +++ b/README.md @@ -1,25 +1,23 @@ [![Build Status](https://github.com/sqlancer/sqlancer/workflows/ci/badge.svg)](https://github.com/sqlancer/sqlancer/actions) -[![Twitter](https://img.shields.io/twitter/follow/sqlancer_dbms?style=social)](https://twitter.com/sqlancer_dbms) -# SQLancer ![SQLancer](media/logo/png/sqlancer_logo_logo_pos_500.png) -SQLancer (Synthesized Query Lancer) is a tool to automatically test Database Management Systems (DBMS) in order to find logic bugs in their implementation. We refer to logic bugs as those bugs that cause the DBMS to fetch an incorrect result set (e.g., by omitting a record). +SQLancer is a tool to automatically test Database Management Systems (DBMSs) in order to find bugs in their implementation. That is, it finds bugs in the code of the DBMS implementation, rather than in queries written by the user. SQLancer has found hundreds of bugs in mature and widely-known DBMSs. -SQLancer operates in the following two phases: +SQLancer tackles two essential challenges when automatically testing the DBMSs: +1. **Test input generation**: SQLancer implements approaches for automatically generating SQL statements. It contains various hand-written SQL generators that operate in multiple phases. First, a database schema is created, which refers to a set of tables and their columns. Then, data is inserted into these tables, along with creating various other kinds of database states such as indexes, views, or database-specific options. Finally, queries are generated, which can be validated using one of multiple result validators (also called *test oracles*) that SQLancer provides. Besides the standard approach of creating the statements in an unguided way, SQLancer also supports a test input-generation approach that is feedback-guided and aims to exercise as many unique query plans as possible based on the intuition that doing so would exercise many interesting behaviors in the database system [[ICSE '23]](https://arxiv.org/pdf/2312.17510). +2. **Test oracles**: A key innovation in SQLancer is that it provides ways to find deep kinds of bugs in DBMSs. As a main focus, it can find logic bugs, which are bugs that cause the DBMS to fetch an incorrect result set (e.g., by omitting a record). We have proposed multiple complementary test oracles such as *Ternary Logic Partitioning (TLP)* [[OOPSLA '20]](https://dl.acm.org/doi/pdf/10.1145/3428279), *Non-optimizing Reference Engine Construction (NoREC)* [[ESEC/FSE 2020]](https://arxiv.org/abs/2007.08292), *Pivoted Query Synthesis (PQS)* [[OSDI '20]](https://www.usenix.org/system/files/osdi20-rigger.pdf), *Differential Query Plans (DQP)* [[SIGMOD '24]](https://dl.acm.org/doi/pdf/10.1145/3654991), and *Constant Optimization Driven Database System Testing (CODDTest)* [SIGMOD '25]. It can also find specific categories of performance issues, which refer to cases where a DBMS could reasonably be expected to produce its result more efficiently using a technique called *Cardinality Estimation Restriction Testing (CERT)* [[ICSE '24]](https://arxiv.org/pdf/2306.00355). SQLancer can detect unexpected internal errors (e.g., an error that the database is corrupted) by declaring all potential errors that might be returned by a DBMS for a given query. Finally, SQLancer can find crash bugs, which are bugs that cause the DBMS process to terminate. For this, it uses an implicit test oracle. -1. Database generation: The goal of this phase is to create a populated database, and stress the DBMS to increase the probability of causing an inconsistent database state that could be detected subsequently. First, random tables are created. Then, randomly SQL statements are chosen to generate, modify, and delete data. Also other statements, such as those to create indexes as well as views and to set DBMS-specific options are sent to the DBMS. -2. Testing: The goal of this phase is to detect the logic bugs based on the generated database. See Testing Approaches below. +**Community.** We have a [Slack workspace](https://join.slack.com/t/sqlancer/shared_invite/zt-eozrcao4-ieG29w1LNaBDMF7OB_~ACg) to discuss SQLancer, and DBMS testing in general. Previously, SQLancer had an account on Twitter/X [@sqlancer_dbms](https://twitter.com/sqlancer_dbms), which is no longer maintained. We have a [blog](https://sqlancer.github.io/posts/), which, as of now, contains only posts by contributors of the [Google Summer of Code project](https://summerofcode.withgoogle.com/archive/2023/organizations/sqlancer). -# Getting Started +# Getting Started [[Video Guide]](https://www.youtube.com/watch?v=lcZ6LixPH1Y) -Requirements: +Minimum Requirements: * Java 11 or above -* [Maven](https://maven.apache.org/) (`sudo apt install maven` on Ubuntu) -* The DBMS that you want to test (embedded DBMSs such as DuckDB, H2, and SQLite do not require a setup) +* [Maven](https://maven.apache.org/) -The following commands clone SQLancer, create a JAR, and start SQLancer to test SQLite using Non-optimizing Reference Engine Construction (NoREC): +The following commands clone SQLancer, create a JAR, and start SQLancer to test SQLite using [Non-optimizing Reference Engine Construction (NoREC)](https://arxiv.org/abs/2007.08292): ``` git clone https://github.com/sqlancer/sqlancer @@ -29,74 +27,90 @@ cd target java -jar sqlancer-*.jar --num-threads 4 sqlite3 --oracle NoREC ``` -If the execution prints progress information every five seconds, then the tool works as expected. Note that SQLancer might find bugs in SQLite. Before reporting these, be sure to check that they can still be reproduced when using the latest development version. The shortcut CTRL+C can be used to terminate SQLancer manually. If SQLancer does not find any bugs, it executes infinitely. The option `--num-tries` can be used to control after how many bugs SQLancer terminates. Alternatively, the option `--timeout-seconds` can be used to specify the maximum duration that SQLancer is allowed to run. +**Running and terminating.** If the execution prints progress information every five seconds, then the tool works as expected. The shortcut CTRL+C can be used to terminate SQLancer manually. If SQLancer does not find any bugs, it executes infinitely. The option `--num-tries` can be used to control after how many bugs SQLancer terminates. Alternatively, the option `--timeout-seconds` can be used to specify the maximum duration that SQLancer is allowed to run. -If you launch SQLancer without parameters, available options and commands are displayed. Note that general options that are supported by all DBMS-testing implementations (e.g., `--num-threads`) need to precede the name of DBMS to be tested (e.g., `sqlite3`). Options that are supported only for specific DBMS (e.g., `--test-rtree` for SQLite3), or options for which each testing implementation provides different values (e.g. `--oracle NoREC`) need to go after the DBMS name. +**Parameters.** If you launch SQLancer without parameters, available options and commands are displayed. Note that general options that are supported by all DBMS-testing implementations (e.g., `--num-threads`) need to precede the name of the DBMS to be tested (e.g., `sqlite3`). Options that are supported only for specific DBMS (e.g., `--test-rtree` for SQLite3), or options for which each testing implementation provides different values (e.g. `--oracle NoREC`) need to go after the DBMS name. -# Testing Approaches +**DBMSs.** To run SQLancer on SQLite, it was not necessary to install and set up a DBMS. The reason for this is that embedded DBMSs run in the same process as the application and thus require no separate installation or setup. Embedded DBMSs supported by SQLancer include DuckDB, H2, and SQLite. Their binaries are included as [JAR dependencies](https://github.com/sqlancer/sqlancer/blob/main/pom.xml). Note that any crashes in these systems will also cause a crash in the JVM on which SQLancer runs. -| Approach | Description | -|------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| Pivoted Query Synthesis (PQS) | PQS is the first technique that we designed and implemented. It randomly selects a row, called a pivot row, for which a query is generated that is guaranteed to fetch the row. If the row is not contained in the result set, a bug has been detected. It is fully described [here](https://arxiv.org/abs/2001.04174). PQS is the most powerful technique, but also requires more implementation effort than the other two techniques. It is currently unmaintained. | -| Non-optimizing Reference Engine Construction (NoREC) | NoREC aims to find optimization bugs. It is described [here](https://www.manuelrigger.at/preprints/NoREC.pdf). It translates a query that is potentially optimized by the DBMS to one for which hardly any optimizations are applicable, and compares the two result sets. A mismatch between the result sets indicates a bug in the DBMS. | -| Ternary Logic Partitioning (TLP) | TLP partitions a query into three partitioning queries, whose results are composed and compare to the original query's result set. A mismatch in the result sets indicates a bug in the DBMS. In contrast to NoREC and PQS, it can detect bugs in advanced features such as aggregate functions. | -Please find the `.bib` entries [here](docs/PAPERS.md). +# Using SQLancer -# Supported DBMS +**Logs.** SQLancer stores logs in the `target/logs` subdirectory. By default, the option `--log-each-select` is enabled, which results in every SQL statement that is sent to the DBMS being logged. The corresponding file names are postfixed with `-cur.log`. In addition, if SQLancer detects a logic bug, it creates a file with the extension `.log`, in which the statements to reproduce the bug are logged, including only the last query that was executed along with the other statements to set up the database state. -Since SQL dialects differ widely, each DBMS to be tested requires a separate implementation. +**Reducing bugs.** After finding a bug-inducing test input, the input typically needs to be reduced to be further analyzed, as it might contain many SQL statements that are redundant to reproduce the bug. One option is to do this manually, by removing a statement or feature at a time, replaying the bug-inducing statements, and applying the test oracle (e.g., for test oracles like TLP or NoREC, this would require checking that both queries still produce a different result). This process can be automated using a so-called [delta-debugging approach](https://www.debuggingbook.org/html/DeltaDebugger.html). SQLancer includes an experimental implementation of a delta debugging approach, which can be enabled using `--use-reducer`. In the past, we have successfully used [C-Reduce](https://embed.cs.utah.edu/creduce/), which requires specifying the test oracle in a script that can be executed by C-Reduce. -| DBMS | Status | Expression Generation | Description | -|-------------|-------------|-----------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| SQLite | Working | Untyped | This implementation is currently affected by a significant performance regression that still needs to be investigated | -| MySQL | Working | Untyped | Running this implementation likely uncovers additional, unreported bugs. | -| PostgreSQL | Working | Typed | | -| Citus (PostgreSQL Extension) | Working | Typed | This implementation extends the PostgreSQL implementation of SQLancer, and was contributed by the Citus team. | -| MariaDB | Preliminary | Untyped | The implementation of this DBMS is very preliminary, since we stopped extending it after all but one of our bug reports were addressed. Running it likely uncovers additional, unreported bugs. | -| CockroachDB | Working | Typed | | -| TiDB | Working | Untyped | | -| DuckDB | Working | Untyped, Generic | | -| ClickHouse | Preliminary | Untyped, Generic | Implementing the different table engines was not convenient, which is why only a very preliminary implementation exists. | -| TDEngine | Removed | Untyped | We removed the TDEngine implementation since all but one of our bug reports were still unaddressed five months after we reported them. | -| OceanBase | Working | Untyped | | +**Testing the latest DBMS version.** For most DBMSs, SQLancer supports only a previous *release* version. Thus, potential bugs that SQLancer finds could be already fixed in the latest *development* version of the DBMS. If you are not a developer of the DBMS that you are testing, we would like to encourage you to validate that the bug can still be reproduced before reporting it. We would appreciate it if you could mention SQLancer when you report bugs found by it. We would also be excited to hear about your experience using SQLancer or related use cases or extensions. +**Options.** SQLancer provides many options that you can use to customize its behavior. Executing `java -jar sqlancer-*.jar --help` will list them and should print output such as the following: +``` +Usage: SQLancer [options] [command] [command options] + Options: + --ast-reducer-max-steps + EXPERIMENTAL Maximum steps the AST-based reducer will do + Default: -1 + --ast-reducer-max-time + EXPERIMENTAL Maximum time duration (secs) the statement reducer will do + Default: -1 + --canonicalize-sql-strings + Should canonicalize query string (add ';' at the end + Default: true + --constant-cache-size + Specifies the size of the constant cache. This option only takes effect + when constant caching is enabled + Default: 100 +... +``` -# Using SQLancer +**Which SQLancer version to use.** The recommended way to use SQLancer is to use its latest source version on GitHub. Infrequent and irregular official releases are also available on the following platforms: +* [GitHub](https://github.com/sqlancer/sqlancer/releases) +* [Maven Central](https://search.maven.org/artifact/com.sqlancer/sqlancer) +* [DockerHub](https://hub.docker.com/r/mrigger/sqlancer) + +**Understanding SQL generation.** To analyze bug-inducing statements, it is helpful to understand the characteristics of SQLancer. First, SQLancer is expected to always generate SQL statements that are syntactically valid for the DBMS under test. Thus, you should never observe any syntax errors. Second, SQLancer might generate statements that are semantically invalid. For example, SQLancer might attempt to insert duplicate values into a column with a `UNIQUE` constraint, as completely avoiding such semantic errors is challenging. Third, any bug reported by SQLancer is expected to be a real bug, except those reported by CERT (as performance issues are not as clearly defined as other kinds of bugs). If you observe any bugs indicated by SQLancer that you do not consider bugs, something is likely wrong with your setup. Finally, related to the aforementioned point, SQLancer is specific to a version of the DBMS, and you can find the version against which we are tested in our [GitHub Actions workflow](https://github.com/sqlancer/sqlancer/blob/documentation/.github/workflows/main.yml). If you are testing against another version, you might observe various false alarms (e.g., caused by syntax errors). While we would always like for SQLancer to be up-to-date with the latest development version of each DBMS, we lack the resources to achieve this. -## Logs +**Supported DBMSs.** SQLancer requires DBMS-specific code for each DBMS that it supports. As of January 2025, it provides support for Citus, ClickHouse, CnosDB, CockroachDB, Databend, (Apache) DataFusion, (Apache) Doris, DuckDB, H2, HSQLDB, MariaDB, Materialize, MySQL, OceanBase, PostgreSQL, Presto, QuestDB, SQLite3, TiDB, and YugabyteDB. The extent to which the individual DBMSs are supported [differs](https://github.com/sqlancer/sqlancer/blob/documentation-approaches/CONTRIBUTING.md). -SQLancer stores logs in the `target/logs` subdirectory. By default, the option `--log-each-select` is enabled, which results in every SQL statement that is sent to the DBMS being logged. The corresponding file names are postfixed with `-cur.log`. In addition, if SQLancer detects a logic bug, it creates a file with the extension `.log`, in which the statements to reproduce the bug are logged. +# Approaches and Papers -## Reducing a Bug +SQLancer has pioneered and includes multiple approaches for DBMS testing, as outlined below in chronological order. -After finding a bug, it is useful to produce a minimal test case before reporting the bug, to save the DBMS developers' time and effort. For many test cases, [C-Reduce](https://embed.cs.utah.edu/creduce/) does a great job. In addition, we have been working on a SQL-specific reducer, which we plan to release soon. +| Technique | Venue | Links | Description | +|-----------------------------------------------------------------|---------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| Pivoted Query Synthesis (PQS) | OSDI 2020 | [Paper](https://www.usenix.org/system/files/osdi20-rigger.pdf) [Video](https://www.youtube.com/watch?v=0aeDyXgzo04 ) | PQS is the first technique that we designed and implemented. It randomly selects a row, called a pivot row, for which a query is generated that is guaranteed to fetch the row. If the row is not contained in the result set, a bug has been detected. It is fully described here. PQS effectively detects bugs, but requires more implementation effort than other testing approaches that follow a metamorphic testing or differential testing methodology. Thus, it is currently unmaintained. | +| Non-optimizing Reference Engine Construction (NoREC) | ESEC/FSE 2020 | [Paper](https://arxiv.org/abs/2007.08292) [Video](https://www.youtube.com/watch?v=4mbzytrWJhQ) | NoREC aims to find optimization bugs. It translates a query that is potentially optimized by the DBMS to one for which hardly any optimizations are applicable, and compares the two result sets. A mismatch between the result sets indicates a bug in the DBMS. The approach applies primarily to simple queries with a filter predicate. | +| Ternary Logic Partitioning (TLP) | OOPSLA 2020 | [Paper](https://dl.acm.org/doi/pdf/10.1145/3428279) [Video](https://www.youtube.com/watch?v=FN9OLbGh0VI) | TLP partitions a query into three partitioning queries, whose results are composed and compared to the original query's result set. A mismatch in the result sets indicates a bug in the DBMS. In contrast to NoREC and PQS, it can detect bugs in advanced features such as aggregate functions. It is among the most widely adopted testing techniques. | +| Differential Query Execution (DQE) | ICSE 2023 | [Paper](https://ieeexplore.ieee.org/document/10172736) [Code](https://github.com/sqlancer/sqlancer/pull/1251) | Differential Query Execution (DQE) is a novel and general approach to detect logic bugs in SELECT, UPDATE and DELETE queries. DQE solves the test oracle problem by executing SELECT, UPDATE and DELETE queries with the same predicate φ, and observing inconsistencies among their execution results. For example, if a row that is updated by an UPDATE query with a predicate φ does not appear in the query result of a SELECT query with the same predicate φ, a logic bug is detected in the target DBMS. We append two extra columns to each table in a database to uniquely identify each row and track whether a row has been modified. We further rewrite SELECT and UPDATE queries to identify their accessed rows. DQE supports MySQL. | +| Query Plan Guidance (QPG) | ICSE 2023 | [Paper](https://arxiv.org/pdf/2312.17510) [Video](https://youtu.be/6EjQ1cKiZJU?si=gh7uoykRqNjl3GXR&t=1820) [Code](https://github.com/sqlancer/sqlancer/issues/641) | QPG is a feedback-guided test case generation approach. It is based on the insights that query plans capture whether interesting behavior is exercised within the DBMS. It works by mutating the database state when no new query plans have been observed after executing a number of queries, expecting that the new state enables new query plans to be triggered. This approach is enabled by option `--qpg-enable` and supports TLP and NoREC oracles for SQLite, CockroachDB, TiDB, and Materialize. It is the only approach that specifically tackles the test input generation problem. | +| Cardinality Estimation Restriction Testing (CERT) | ICSE 2024 | [Paper](https://arxiv.org/pdf/2306.00355) [Code](https://github.com/sqlancer/sqlancer/issues/822) | CERT aims to find performance issues through unexpected estimated cardinalities, which represent the estimated number of returned rows. From a given input query, it derives a more restrictive query, whose estimated cardinality should be no more than that of the original query. A violation indicates a potential performance issue. CERT supports TiDB, CockroachDB, and MySQL. CERT is the only test oracle that is part of SQLancer that was designed to find performance issues. | +| Differential Query Plans (DQP) | SIGMOD 2024 | [Paper](https://dl.acm.org/doi/pdf/10.1145/3654991) [Video](https://www.youtube.com/watch?v=9Qp7quJfGEk) [Code](https://github.com/sqlancer/sqlancer/issues/918) | DQP aims to find logic bugs by controlling the execution of different query plans for a given query and validating that they produce a consistent result. DQP supports MySQL, MariaDB, and TiDB. | +| Constant Optimization Driven Database System Testing (CODDTest) | SIGMOD 2025 | [Code](https://github.com/sqlancer/sqlancer/pull/1054) | CODDTest finds logic bugs in DBMSs, including in advanced features such as subqueries. It is based on the insight that we can assume the database state to be constant for a database session, which then enables us to substitute parts of a query with their results, essentially corresponding to constant folding and constant propagation, which are two traditional compiler optimizations. | -## Found Bugs +Please find the `.bib` entries [here](docs/PAPERS.md). | -We would appreciate it if you mention SQLancer when you report bugs found by it. We would also be excited to know if you are using SQLancer to find bugs, or if you have extended it to test another DBMS (also if you do not plan to contribute it to this project). SQLancer has found over 400 bugs in widely-used DBMS, which are listed [here](https://www.manuelrigger.at/dbms-bugs/). +# FAQ +**I am running SQLancer on the latest version of a supported DBMS. Is it expected that SQLancer prints many AssertionErrors?** In many cases, SQLancer does not support the latest version of a DBMS. You can check the [`.github/workflows/main.yml`](https://github.com/sqlancer/sqlancer/blob/master/.github/workflows/main.yml) file to determine which version we use in our CI tests, which corresponds to the currently supported version of that DBMS. SQLancer should print only an `AssertionError` and produce a corresponding log file, if it has identified a bug. To upgrade SQLancer to support a new DBMS version, either two options are advisable: (1) the generators can be updated to no longer generate certain patterns that might cause errors (e.g., which might be the case if a keyword or option is no longer supported) or (2) the newly-appearing errors can be added as [expected errors](https://github.com/sqlancer/sqlancer/blob/354d591cfcd37fa1de85ec77ec933d5d975e947a/src/sqlancer/common/query/ExpectedErrors.java) so that SQLancer ignores them when they appear (e.g., this is useful if some error-inducing patterns cannot easily be avoided). -# Community +Another reason for many failures on a supported version could be that error messages are printed in a non-English locale (which would then be visible in the stack trace). In such a case, try setting the DBMS' locale to English (e.g., see the [PostgreSQL homepage](https://www.postgresql.org/docs/current/locale.html)). -We have created a [Slack workspace](https://join.slack.com/t/sqlancer/shared_invite/zt-eozrcao4-ieG29w1LNaBDMF7OB_~ACg) to discuss SQLancer, and DBMS testing in general. SQLancer's official Twitter handle is [@sqlancer_dbms](https://twitter.com/sqlancer_dbms). +**When starting SQLancer, I get an error such as "database 'test' does not exist". How can I run SQLancer without this error?** For some DBMSs, SQLancer expects that a database "test" exists, which it then uses as an initial database to connect to. If you have not yet created such a database, you can use a command such as `CREATE DATABASE test` to create this database (e.g., see the [PostgreSQL documentation](https://www.postgresql.org/docs/current/sql-createdatabase.html)). +# Links -# Additional Documentation +Documentation and resources: * [Contributing to SQLancer](CONTRIBUTING.md) * [Papers and .bib entries](docs/PAPERS.md) +* More information on our DBMS testing efforts and the bugs we found is available [here](https://www.manuelrigger.at/dbms-bugs/). -# Releases - -Official release are available on: -* [GitHub](https://github.com/sqlancer/sqlancer/releases) -* [Maven Central](https://search.maven.org/artifact/com.sqlancer/sqlancer) -* [DockerHub](https://hub.docker.com/r/mrigger/sqlancer) - -# Additional Resources +Videos: +* [SQLancer Tutorial Playlist](https://www.youtube.com/playlist?list=PLm7ofmclym1E2LwBeSer_AAhzBSxBYDci) +* [SQLancer Talks](https://youtube.com/playlist?list=PLm7ofmclym1E9-AbYy-PkrMfHpB9VdlZJ) -* A talk on Ternary Logic Partitioning (TLP) and SQLancer is available on [YouTube](https://www.youtube.com/watch?v=Np46NQ6lqP8). -* An (older) Pivoted Query Synthesis (PQS) talk is available on [YouTube](https://www.youtube.com/watch?v=yzENTaWe7qg). -* PingCAP has implemented PQS, NoREC, and TLP in a tool called [go-sqlancer](https://github.com/chaos-mesh/go-sqlancer). -* More information on our DBMS testing efforts and the bugs we found is available [here](https://www.manuelrigger.at/dbms-bugs/). +Closely related tools: +* [go-sqlancer](https://github.com/chaos-mesh/go-sqlancer): re-implementation of some of SQLancer's approaches in Go by PingCAP +* [Jepsen](https://github.com/jepsen-io): testing of distributed (database) systems +* [SQLRight](https://github.com/PSU-Security-Universe/sqlright): coverage-guided DBMS fuzzer, also supporting NoREC and TLP +* [SQLsmith](https://github.com/anse1/sqlsmith): random SQL query generator used for fuzzing +* [Squirrel](https://github.com/s3team/Squirrel): coverage-guided DBMS fuzzer diff --git a/configs/spotbugs-exclude.xml b/configs/spotbugs-exclude.xml index 366e8ac55..7fa4de560 100644 --- a/configs/spotbugs-exclude.xml +++ b/configs/spotbugs-exclude.xml @@ -7,9 +7,19 @@ + + + + + - + + + + + + diff --git a/docs/PAPERS.md b/docs/PAPERS.md index ca9d40de7..a42b42c12 100644 --- a/docs/PAPERS.md +++ b/docs/PAPERS.md @@ -1,6 +1,6 @@ # Papers -The testing approaches implemented in SQLancer are described in the three papers below. +The testing approaches implemented in SQLancer are described in the four papers below. ## Testing Database Engines via Pivoted Query Synthesis @@ -51,6 +51,51 @@ This paper describes TLP, a metamorphic testing approach that can detect various } ``` +## Testing Database Engines via Query Plan Guidance + +This paper describes Query Plan Guidance (QPG), a test case generation method guided by query plan coverage. This method can be paired with above three testing methods. A preprint is available [here](http://bajinsheng.github.io/assets/pdf/qpg_icse23.pdf). + +``` +@inproceedings{Ba2023QPG, + author = {Ba, Jinsheng and Rigger, Manuel}, + title = {Testing Database Engines via Query Plan Guidance}, + booktitle = {The 45th International Conference on Software Engineering (ICSE'23)}, + year = {2023}, + month = may +} +``` + +## CERT: Finding Performance Issues in Database Systems Through the Lens of Cardinality Estimation + +This paper describes CERT, a testing approach to find performance issues by inspecting inconsistent estimated cardinalities. A preprint is available [here](https://bajinsheng.github.io/assets/pdf/cert_icse24.pdf). + +``` +@inproceedings{cert, + author = {Ba, Jinsheng and Rigger, Manuel}, + title = {CERT: Finding Performance Issues in Database Systems Through the Lens of Cardinality Estimation}, + booktitle = {The 46th International Conference on Software Engineering (ICSE'24)}, + year = {2024}, + month = apr, +} +``` + +## Keep It Simple: Testing Databases via Differential Query Plans + +This paper describes DQP, a testing approach to find logic bugs in database systems by comparing the query plans of different database systems. A preprint is available [here](https://bajinsheng.github.io/assets/pdf/dqp_sigmod24.pdf). + +``` +@article{dqp, + author = {Ba, Jinsheng and Rigger, Manuel}, + title = {Keep It Simple: Testing Databases via Differential Query Plans}, + year = {2024}, + issue_date = {June 2024}, + publisher = {Association for Computing Machinery}, + address = {New York, NY, USA}, + journal = {Proceeding of ACM Management of Data (SIGMOD'24)}, + month = jun +} +``` + # Comparing SQLancer With Other Tools that Find Logic Bugs If you want to fairly compare other tools with SQLancer, we would be glad to provide feedback (e.g., feel free to send an email to manuel.rigger@inf.ethz.ch). We have the following general recommendations and comments: diff --git a/docs/QueryPlanGuidance.md b/docs/QueryPlanGuidance.md new file mode 100644 index 000000000..bb467461b --- /dev/null +++ b/docs/QueryPlanGuidance.md @@ -0,0 +1,66 @@ +# Query Plan Guidance +Query Plan Guidance (QPG) is a test case generation method that attempts to explore unseen query plans. Given a database state, we mutate it after no new unique query plans have been observed by randomly-generated queries on the database state aiming to cover more unique query plans for exposing more logics of DBMSs. Here, we document all mutators in which we choose the most promising one that may help covering more unique query plans to execute. + +# Mutators +All mutators are listed below and implemented in the enumeration variables `Action` in the `XXDBProvider.java` file of each DBMS. +The `Mutator` column includes the items in the `Action` enumeration variable. +The `Example` column includes an example of a realistic statement generated by this mutator. +The `Description` column includes an explanation of what the mutator does. +The `More unique query plans...` column explains why applying this mutator may help covering more unique query plans. + + +|DBMS |Mutator |Example |Description |More unique query plans may be covered because of | +|-----------|---------------------|--------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------|--------------------------------------------------------| +|SQLite |PRAGMA |PRAGMA automatic_index true; |It modifies database options. |different options that decide how to execute statements.| +|SQLite |CREATE_INDEX |CREATE INDEX i0 ON t0 WHERE c0 ISNULL; |It adds a new index on a table. |subsequent differnt logic of querying data. | +|SQLite |CREATE_VIEW |CREATE VIEW v0(c0) AS SELECT DISTINCT ABS(t0.c2) FROM t0; |It adds a new view from existing tables. |more possible execution logics on the view. | +|SQLite |CREATE_TABLE |CREATE TABLE t0 (c0 INT CHECK ((c0) BETWEEN (1) AND (10)) ); |It adds a new table. |more possible execution logics on the table. | +|SQLite |CREATE_VIRTUALTABLE |CREATE VIRTUAL TABLE vt1 USING fts5(c0 UNINDEXED); |It adds a new table with fts5 feature. |more possible execution logics on the table with fts5. | +|SQLite |CREATE_RTREETABLE |CREATE VIRTUAL TABLE rt0 USING rtree_i32(c0, c1, c2, c3, c4); |It adds a new table with rtree feature. |more possible execution logics on the table with rtree. | +|SQLite |INSERT |INSERT INTO t0(c0, c1) VALUES ('lrd+a*', NULL); |It inserts a new row to a table. |subsequent different logic of querying data. | +|SQLite |DELETE |DELETE FROM t0 WHERE (c0>3); |It deletes specific rows from a table. |subsequent different logic of querying data. | +|SQLite |ALTER |ALTER TABLE t0 ADD COLUMN c39 REAL; |It changes the schema of a table. |more possible execution logics on the changed table. | +|SQLite |UPDATE |UPDATE t0 SET (c2, c0)=(-944, 'L((xA') WHERE t0.c1; |It updates specific data of a table. |subsequent different logic of querying data. | +|SQLite |DROP_INDEX |DROP INDEX i0; |It drops an index. |subsequent different logic of querying data. | +|SQLite |DROP_TABLE |DROP TABLE t0; |it drops an table. |subsequent different logic of querying data. | +|SQLite |DROP_VIEW |DROP VIEW v0; |It drops a view. |subsequent different logic of querying data. | +|SQLite |VACUUM |VACUUM main; |It rebuilds the database file. |subsequent different logic of querying data. | +|SQLite |REINDEX |REINDEX t0; |It drops and recreates indexes from scratch. |subsequent different logic of querying data. | +|SQLite |ANALYZE |ANALYZE t0; |It gathers statistics about tables to help make better query planning choices.|subsequent different logic of querying data. | +|SQLite |EXPLAIN |EXPLAIN SELECT * FROM t0; |It obtains query plan of a query. |subsequent different logic of querying data. | +|SQLite |CHECK_RTREE_TABLE |SELECT rtreecheck('rt0'); |It runs an integrity check on a table. |subsequent different logic of querying data. | +|SQLite |VIRTUAL_TABLE_ACTION |INSERT INTO vt0(vt0) VALUES('rebuild'); |It changes the options of a virtual table. |subsequent different logic of querying data. | +|SQLite |MANIPULATE_STAT_TABLE|INSERT INTO sqlite_stat1 VALUES('rt0', 't1', '2'); |It changes the table that stores statistics of all tables. |subsequent different logic of querying data. | +|SQLite |TRANSACTION_START |BEGIN TRANSACTION; |All statements after this will not be committed. |subsequent different logic of querying data. | +|SQLite |ROLLBACK_TRANSACTION |ROLLBACK TRANSACTION; |All statements after last BEGIN are dropped. |subsequent different logic of querying data. | +|SQLite |COMMIT |COMMIT; |All statements after last BEGIN are committed |subsequent different logic of querying data. | +|TiDB |CREATE_TABLE |CREATE TABLE t1(c0 INT); |It adds a new table. |more possible execution logics on the table. | +|TiDB |CREATE_INDEX |CREATE INDEX i0 ON t0(c0(250) ASC) KEY_BLOCK_SIZE 1564693810209727437; |It adds a new index on a table. |subsequent differnt logic of querying data. | +|TiDB |VIEW_GENERATOR |CREATE VIEW v0(c0, c1) AS SELECT t1.c0, ((t1.c0)REGEXP('8')) FROM t1; |It adds a new view from existing tables. |more possible execution logics on the view. | +|TiDB |INSERT |INSERT INTO t0(c0) VALUES (-16387); |It inserts a new row to a table. |subsequent different logic of querying data. | +|TiDB |ALTER_TABLE |ALTER TABLE t1 ADD PRIMARY KEY(c0); |It changes the schema of a table. |more possible execution logics on the changed table. | +|TiDB |TRUNCATE |TRUNCATE t0; |It drops all rows of a table. |subsequent different logic of querying data. | +|TiDB |UPDATE |UPDATE t0 SET c0='S' WHERE t0.c0; |It updates specific data of a table. |subsequent different logic of querying data. | +|TiDB |DELETE |DELETE FROM t0 ORDER BY CAST(t0.c0 AS CHAR) DESC; |It deletes specific rows from a table. |subsequent different logic of querying data. | +|TiDB |SET |set @@tidb_max_chunk_size=8864; |It modifies database options. |different options that decide how to execute statements.| +|TiDB |ADMIN_CHECKSUM_TABLE |ADMIN CHECKSUM TABLE t0; |it calculate the checksum for a table. |subsequent different logic of querying data. | +|TiDB |ANALYZE_TABLE |ANALYZE TABLE t1 WITH 174 BUCKETS; |It gathers statistics about tables to help make better query planning choices.|subsequent different logic of querying data. | +|TiDB |DROP_TABLE |DROP TABLE t0; |it drops an table. |subsequent different logic of querying data. | +|TiDB |DROP_VIEW |DROP VIEW v0; |It drops a view. |subsequent different logic of querying data. | +|CockroachDB|CREATE_TABLE |CREATE TABLE t1 (c0 INT4, c1 VARBIT(44) UNIQUE DEFAULT (B'000'), CONSTRAINT "primary" PRIMARY KEY(c1 ASC, c0 ASC));|It adds a new table. |more possible execution logics on the table. | +|CockroachDB|CREATE_INDEX |CREATE INDEX ON t0(rowid); |It adds a new index on a table. |subsequent differnt logic of querying data. | +|CockroachDB|CREATE_VIEW |CREATE VIEW v0(c0) AS SELECT DISTINCT MIN(TIMETZ '1970-01-11T12:19:44') FROM t0; |It adds a new view from existing tables. |more possible execution logics on the view. | +|CockroachDB|CREATE_STATISTICS |CREATE STATISTICS s0 FROM t2; |It gathers statistics about tables to help make better query planning choices.|subsequent different logic of querying data. | +|CockroachDB|INSERT |INSERT INTO t1 (rowid, c0) VALUES(NULL, true) ON CONFLICT (c0) DO NOTHING ; |It inserts a new row to a table. |subsequent different logic of querying data. | +|CockroachDB|UPDATE |UPDATE t0@{FORCE_INDEX=t0_pkey} SET c0=t0.c0; |It updates specific data of a table. |subsequent different logic of querying data. | +|CockroachDB|SET_SESSION |SET SESSION BYTEA_OUTPUT=escape; |It changes session configurations. |different options that decide how to execute statements.| +|CockroachDB|SET_CLUSTER_SETTING |SET CLUSTER SETTING sql.query_cache.enabled=true; |It changes cluster configurations. |different options that decide how to execute statements.| +|CockroachDB|DELETE |DELETE from t0; |It deletes specific rows from a table. |subsequent different logic of querying data. | +|CockroachDB|TRUNCATE |TRUNCATE TABLE t1 CASCADE; |It drops all rows of a table. |subsequent different logic of querying data. | +|CockroachDB|DROP_TABLE |DROP TABLE t0; |it drops an table. |subsequent different logic of querying data. | +|CockroachDB|DROP_VIEW |DROP VIEW v0; |It drops a view. |subsequent different logic of querying data. | +|CockroachDB|COMMENT_ON |COMMENT ON INDEX t0_c0_key IS '|?'; |It changes schema of a table. |subsequent different logic of querying data. | +|CockroachDB|SHOW |SHOW LOCALITY; |It lists detailed information of active queries. |subsequent different logic of querying data. | +|CockroachDB|EXPLAIN |EXPLAIN SELECT * FROM t0; |It obtains query plan of a query. |subsequent different logic of querying data. | +|CockroachDB|SCRUB |EXPERIMENTAL SCRUB table t0; |It checks data corruption of a table. |subsequent different logic of querying data. | +|CockroachDB|SPLIT |ALTER INDEX t0@t0_c0_key SPLIT AT VALUES (NULL); |It changes the indexes. |subsequent different logic of querying data. | diff --git a/docs/testCaseReduction.md b/docs/testCaseReduction.md new file mode 100644 index 000000000..ee317f791 --- /dev/null +++ b/docs/testCaseReduction.md @@ -0,0 +1,50 @@ +# Test Case Reduction +SQLancer generates a large number of statements, but not all of them are relevant to the bug. To automatically reduce the test cases, two reducers were implemented: the statement reducer and the AST-based reducer. + +## Statement Reducer +The statement reducer utilizes the delta-debugging technique to remove irrelevant statements. More details of delta-debugging could be found in this paper: [Simplifying and Isolating Failure-Inducing Input](https://www.cs.purdue.edu/homes/xyzhang/fall07/Papers/delta-debugging.pdf). + +Using the statement reducer, SQLancer reduces the set of statements to a minimal subset that reproduces the bug. + +## AST-Based Reducer +The AST-based reducer can shorten a statement by applying AST level transformations, including removing unnecessary clauses, irrelevant elements in a list, simplify complicated expressions and etc. + +The transformations are implemented by [JSQLParser](https://github.com/JSQLParser/JSqlParser), a RDBMS agnostic SQL statement parser that can translate SQL statements into a traversable hierarchy of Java classes. JSQLParser provides support for the SQL standard as well as major SQL dialects. The AST-based reducer works for any SQL dialects that can be parsed by this tool. + +## Implementing reproducer +Determining whether a bug persists after reducing statements +is an undecidable task for general transformations. +In practice, reducers use the [reproducer](../src/sqlancer/Reproducer.java) to determine +if a bug remains after statements have been removed or modified. +The reducer's responsibility is to verify if the current state, +formed by the pared-down statements, +continues to yield incorrect results for specific queries. + +Different oracles have distinct logic for determination, +meaning a universal reproducer doesn't exist. +Each oracle type needs its own reproducer implementation. +If reproducer is not implemented for specific oracle, +test case reduction is not available while using the oracle. + +Oracles for which reproducers have currently been implemented include: +1. for [`SQLite3NoRECOracle`](../src/sqlancer/sqlite3/oracle/SQLite3NoRECOracle.java) +2. for [`TiDBTLPWhereOracle`](../src/sqlancer/tidb/oracle/TiDBTLPWhereOracle.java) + +## Using reducers +Test-case reduction is disabled by default. The statement reducer can be enabled by passing `--use-reducer` when starting SQLancer. If you wish to further shorten each statements, you need to additionally pass the `--reduce-ast` parameter so that the AST-based reduction is applied. + +Note: if `--reduce-ast` is set, `--use-reducer` option must be enabled first. + +There are also options to define timeout seconds and max steps of reduction for both statement reducer and AST-based reducer. + +``` +--statement-reducer-max-steps= +--statement-reducer-max-time= +--ast-reducer-max-steps= +--ast-reducer-max-time= +``` + +## Reduction logs +If test-case reduction is enabled, each time the reducer performs a reduction step successfully,it prints the reduced statements to the log file, overwriting the previous ones. + +The log files will be stored in the following format: `logs//reduce/-reduce.log`. For instance, if the tested DBMS is SQLite3 and the current database is named database0, the log file will be located at `logs/sqlite3/reduce/database0-reduce.log`. diff --git a/pom.xml b/pom.xml index 234d805f1..7c9a1106b 100644 --- a/pom.xml +++ b/pom.xml @@ -44,7 +44,7 @@ org.apache.maven.plugins maven-shade-plugin - 3.3.0 + 3.4.0 package @@ -89,7 +89,7 @@ org.jacoco jacoco-maven-plugin - 0.8.8 + 0.8.12 @@ -123,7 +123,7 @@ org.codehaus.plexus plexus-compiler-eclipse - 2.12.1 + 2.13.0 org.eclipse.jdt @@ -133,7 +133,7 @@ org.codehaus.plexus plexus-compiler-api - 2.12.1 + 2.13.0 @@ -154,7 +154,7 @@ org.apache.maven.plugins maven-dependency-plugin - 3.2.0 + 3.4.0 copy-dependencies @@ -175,7 +175,7 @@ org.apache.maven.plugins maven-jar-plugin - 3.2.2 + 3.3.0 true @@ -209,7 +209,7 @@ com.puppycrawl.tools checkstyle - 10.3.3 + 10.5.0 @@ -252,7 +252,7 @@ com.github.spotbugs spotbugs-maven-plugin - 4.7.1.1 + 4.7.3.0 spotbugs @@ -284,12 +284,22 @@ org.postgresql postgresql - 42.5.0 + 42.5.1 + + + com.ing.data + cassandra-jdbc-wrapper + 4.7.0 + + + com.yugabyte + jdbc-yugabytedb + 42.3.5-yb-1 org.xerial sqlite-jdbc - 3.36.0.3 + 3.49.1.0 mysql @@ -299,23 +309,28 @@ org.mariadb.jdbc mariadb-java-client - 3.0.7 + 3.1.0 org.duckdb duckdb_jdbc - 0.4.0 + 1.3.0.0 + + + com.facebook.presto + presto-jdbc + 0.283 org.junit.jupiter junit-jupiter-engine - 5.9.0 + 5.11.2 test org.slf4j - slf4j-simple - 1.7.36 + slf4j-simple + 2.0.6 ru.yandex.clickhouse @@ -325,7 +340,7 @@ com.h2database h2 - 2.1.214 + 2.3.232 org.mongodb @@ -337,6 +352,85 @@ arangodb-java-driver 6.9.0 + + org.questdb + questdb + 6.5.3 + + + org.hsqldb + hsqldb + 2.7.4 + runtime + + + org.apache.commons + commons-csv + 1.9.0 + + + com.github.jsqlparser + jsqlparser + 4.6 + + + org.apache.arrow + flight-sql-jdbc-driver + 16.1.0 + + + org.apache.hive + hive-jdbc + 3.1.2 + + + org.apache.logging.log4j + log4j-slf4j-impl + + + + + org.apache.hive + hive-serde + 4.0.1 + + + org.apache.logging.log4j + log4j-slf4j-impl + + + + + org.apache.hive + hive-cli + 4.0.1 + + + org.apache.logging.log4j + log4j-slf4j-impl + + + + + org.apache.logging.log4j + log4j-api + 2.24.3 + + + org.apache.logging.log4j + log4j-core + 2.24.3 + + + org.apache.logging.log4j + log4j-slf4j2-impl + 2.24.3 + + + org.apache.hadoop + hadoop-common + 3.2.4 + @@ -435,5 +529,23 @@ + + datafusion-tests + + + + org.apache.maven.plugins + maven-surefire-plugin + 3.3.0 + + + **/TestDataFusion.java + + --add-opens java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED + + + + + diff --git a/src/check_names.py b/src/check_names.py index f76b881ab..f2ab346c6 100644 --- a/src/check_names.py +++ b/src/check_names.py @@ -1,28 +1,56 @@ import os +import sys +from typing import List -def get_java_files(directory): - java_files = [] - for root, dirs, files in os.walk(directory): - for f in files: - if f.endswith('.java'): - java_files.append(f) - return java_files - -def verify_prefix(prefix, files): - if len(files) == 0: - print(prefix + ' directory does not contain any files!') - exit(-1) - for f in files: - if not f.startswith(prefix): - print('The class name of ' + f + ' does not start with ' + prefix) - exit(-1) - -# TODO: ClickHouse (wait for https://github.com/sqlancer/sqlancer/pull/39) -verify_prefix('CockroachDB', get_java_files("sqlancer/cockroachdb/")) -verify_prefix('DuckDB', get_java_files("sqlancer/duckdb")) -verify_prefix('MariaDB', get_java_files("sqlancer/mariadb/")) -verify_prefix('MySQL', get_java_files("sqlancer/mysql/")) -verify_prefix('Postgres', get_java_files("sqlancer/postgres/")) -verify_prefix('SQLite3', get_java_files("sqlancer/sqlite3/")) -verify_prefix('TiDB', get_java_files("sqlancer/tidb/")) +def get_java_files(directory_path: str) -> List[str]: + java_files: List[str] = [] + for root, dirs, files in os.walk(directory_path): + for f in files: + if f.endswith('.java'): + java_files.append(f) + return java_files + + +def verify_one_db(prefix: str, files: List[str]): + print('checking database, name: {0}, files: {1}'.format(prefix, files)) + if len(files) == 0: + print(prefix + ' directory does not contain any files!', file=sys.stderr) + exit(-1) + for f in files: + if not f.startswith(prefix): + print('The class name of ' + f + ' does not start with ' + prefix, file=sys.stderr) + exit(-1) + print('checking database pass: ', prefix) + + +def verify_all_dbs(name_to_files: dict[str:List[str]]): + for db_name, files in name_to_files.items(): + verify_one_db(db_name, files) + + +if __name__ == '__main__': + cwd = os.getcwd() + print("Current working directory: {0}".format(cwd)) + name_to_files: dict[str:List[str]] = dict() + name_to_files["Citus"] = get_java_files(os.path.join(cwd, "src", "sqlancer", "citus")) + name_to_files["ClickHouse"] = get_java_files(os.path.join(cwd, "src", "sqlancer", "clickhouse")) + name_to_files["CnosDB"] = get_java_files(os.path.join(cwd, "src", "sqlancer", "cnosdb")) + name_to_files["CockroachDB"] = get_java_files(os.path.join(cwd, "src", "sqlancer", "cockroachdb")) + name_to_files["Databend"] = get_java_files(os.path.join(cwd, "src", "sqlancer", "databend")) + name_to_files["DataFusion"] = get_java_files(os.path.join(cwd, "src", "sqlancer", "datafusion")) + name_to_files["DuckDB"] = get_java_files(os.path.join(cwd, "src", "sqlancer", "duckdb")) + name_to_files["H2"] = get_java_files(os.path.join(cwd, "src", "sqlancer", "h2")) + name_to_files["HSQLDB"] = get_java_files(os.path.join(cwd, "src", "sqlancer", "hsqldb")) + name_to_files["MariaDB"] = get_java_files(os.path.join(cwd, "src", "sqlancer", "mariadb")) + name_to_files["Materialize"] = get_java_files(os.path.join(cwd, "src", "sqlancer", "materialize")) + name_to_files["MySQL"] = get_java_files(os.path.join(cwd, "src", "sqlancer", "mysql")) + name_to_files["OceanBase"] = get_java_files(os.path.join(cwd, "src", "sqlancer", "oceanbase")) + name_to_files["Postgres"] = get_java_files(os.path.join(cwd, "src", "sqlancer", "postgres")) + name_to_files["Presto"] = get_java_files(os.path.join(cwd, "src", "sqlancer", "presto")) + name_to_files["QuestDB"] = get_java_files(os.path.join(cwd, "src", "sqlancer", "questdb")) + name_to_files["SQLite3"] = get_java_files(os.path.join(cwd, "src", "sqlancer", "sqlite3")) + name_to_files["TiDB"] = get_java_files(os.path.join(cwd, "src", "sqlancer", "tidb")) + name_to_files["Y"] = get_java_files(os.path.join(cwd, "src", "sqlancer", "yugabyte")) # has both YCQL and YSQL prefixes + name_to_files["Doris"] = get_java_files(os.path.join(cwd, "src", "sqlancer", "doris")) + verify_all_dbs(name_to_files) diff --git a/src/sqlancer/ASTBasedReducer.java b/src/sqlancer/ASTBasedReducer.java new file mode 100644 index 000000000..876a2da12 --- /dev/null +++ b/src/sqlancer/ASTBasedReducer.java @@ -0,0 +1,144 @@ +package sqlancer; + +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; + +import sqlancer.common.query.Query; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.transformations.RemoveClausesOfSelect; +import sqlancer.transformations.RemoveColumnsOfSelect; +import sqlancer.transformations.RemoveElementsOfExpressionList; +import sqlancer.transformations.RemoveRowsOfInsert; +import sqlancer.transformations.RemoveUnions; +import sqlancer.transformations.RoundDoubleConstant; +import sqlancer.transformations.SimplifyConstant; +import sqlancer.transformations.SimplifyExpressions; +import sqlancer.transformations.Transformation; + +public class ASTBasedReducer, O extends DBMSSpecificOptions, C extends SQLancerDBConnection> + implements Reducer { + + private final DatabaseProvider provider; + + @SuppressWarnings("unused") + private G state; + private G newGlobalState; + private Reproducer reproducer; + + private List> reducedStatements; + // statement after reduction. + + public ASTBasedReducer(DatabaseProvider provider) { + this.provider = provider; + } + + @SuppressWarnings("unchecked") + private void updateStatements(String queryString, int index) { + boolean couldAffectSchema = queryString.contains("CREATE TABLE") || queryString.contains("EXPLAIN"); + reducedStatements.set(index, (Query) new SQLQueryAdapter(queryString, couldAffectSchema)); + } + + @SuppressWarnings("unchecked") + @Override + public void reduce(G state, Reproducer reproducer, G newGlobalState) throws Exception { + this.state = state; + this.newGlobalState = newGlobalState; + this.reproducer = reproducer; + + long maxReduceTime = state.getOptions().getMaxStatementReduceTime(); + long maxReduceSteps = state.getOptions().getMaxStatementReduceSteps(); + + List> initialBugInducingStatements = state.getState().getStatements(); + newGlobalState.getState().setStatements(new ArrayList<>(initialBugInducingStatements)); + + List transformations = new ArrayList<>(); + + transformations.add(new RemoveUnions()); + transformations.add(new RemoveClausesOfSelect()); + transformations.add(new RemoveRowsOfInsert()); + transformations.add(new RemoveColumnsOfSelect()); + transformations.add(new RemoveElementsOfExpressionList()); + transformations.add(new SimplifyExpressions()); + transformations.add(new SimplifyConstant()); + transformations.add(new RoundDoubleConstant()); + + Transformation.setBugJudgement(() -> { + try { + return this.bugStillTriggers(); + } catch (Exception ignored) { + } + return false; + }); + + boolean observeChange; + reducedStatements = new ArrayList<>(); + for (Query query : initialBugInducingStatements) { + reducedStatements.add((Query) query); + } + + Instant startTime = Instant.now(); + reduceProcess: do { + observeChange = false; + for (Transformation t : transformations) { + for (int i = 0; i < reducedStatements.size(); i++) { + + Instant currentTime = Instant.now(); + if (maxReduceTime != MainOptions.NO_REDUCE_LIMIT + && Duration.between(startTime, currentTime).getSeconds() >= maxReduceTime) { + break reduceProcess; + } + + if (maxReduceSteps != MainOptions.NO_REDUCE_LIMIT + && Transformation.getReduceSteps() >= maxReduceSteps) { + break reduceProcess; + } + + Query query = reducedStatements.get(i); + boolean initFlag = t.init(query.getQueryString()); + int index = i; + t.setStatementChangedCallBack((statementString) -> { + updateStatements(statementString, index); + }); + + if (!initFlag) { + newGlobalState.getLogger() + .logReducer("warning: failed parsing the statement at transformer : " + t); + continue; + } + t.apply(); + observeChange |= t.changed(); + } + } + } while (observeChange); + + newGlobalState.getState().setStatements(new ArrayList<>(reducedStatements)); + newGlobalState.getLogger().logReduced(newGlobalState.getState()); + } + + public boolean bugStillTriggers() throws Exception { + try (C con2 = provider.createDatabase(newGlobalState)) { + newGlobalState.setConnection(con2); + List> candidateStatements = new ArrayList<>(reducedStatements); + newGlobalState.getState().setStatements(new ArrayList<>(candidateStatements)); + + for (Query s : candidateStatements) { + try { + s.execute(newGlobalState); + } catch (Throwable ignoredException) { + // ignore + } + } + try { + if (reproducer.bugStillTriggers(newGlobalState)) { + newGlobalState.getLogger().logReduced(newGlobalState.getState()); + return true; + } + } catch (Throwable ignoredException) { + + } + } + return false; + } +} diff --git a/src/sqlancer/ComparatorHelper.java b/src/sqlancer/ComparatorHelper.java index 880d3ae72..cee290924 100644 --- a/src/sqlancer/ComparatorHelper.java +++ b/src/sqlancer/ComparatorHelper.java @@ -48,7 +48,8 @@ public static List getResultSetFirstColumnAsString(String queryString, E e.printStackTrace(); } } - SQLQueryAdapter q = new SQLQueryAdapter(queryString, errors); + boolean canonicalizeString = state.getOptions().canonicalizeSqlString(); + SQLQueryAdapter q = new SQLQueryAdapter(queryString, errors, true, canonicalizeString); List resultSet = new ArrayList<>(); SQLancerResultSet result = null; try { @@ -60,7 +61,7 @@ public static List getResultSetFirstColumnAsString(String queryString, E String resultTemp = result.getString(1); if (resultTemp != null) { resultTemp = resultTemp.replaceAll("[\\.]0+$", ""); // Remove the trailing zeros as many DBMS treat - // it as non-bugs + // it as non-bugs } resultSet.add(resultTemp); } @@ -69,11 +70,12 @@ public static List getResultSetFirstColumnAsString(String queryString, E throw e; } - if (e.getMessage() == null) { - throw new AssertionError(queryString, e); - } - if (errors.errorIsExpected(e.getMessage())) { - throw new IgnoreMeException(); + Throwable current = e; + while (current != null) { + if (current.getMessage() != null && errors.errorIsExpected(current.getMessage())) { + throw new IgnoreMeException(); + } + current = current.getCause(); } throw new AssertionError(queryString, e); } finally { @@ -87,32 +89,42 @@ public static List getResultSetFirstColumnAsString(String queryString, E public static void assumeResultSetsAreEqual(List resultSet, List secondResultSet, String originalQueryString, List combinedString, SQLGlobalState state) { if (resultSet.size() != secondResultSet.size()) { - String queryFormatString = "-- %s;\n-- cardinality: %d"; + String queryFormatString = "-- %s;" + System.lineSeparator() + "-- cardinality: %d" + + System.lineSeparator(); String firstQueryString = String.format(queryFormatString, originalQueryString, resultSet.size()); - String secondQueryString = String.format(queryFormatString, - combinedString.stream().collect(Collectors.joining(";")), secondResultSet.size()); - state.getState().getLocalState().log(String.format("%s\n%s", firstQueryString, secondQueryString)); - String assertionMessage = String.format("the size of the result sets mismatch (%d and %d)!\n%s\n%s", - resultSet.size(), secondResultSet.size(), firstQueryString, secondQueryString); + String combinedQueryString = String.join(";", combinedString); + String secondQueryString = String.format(queryFormatString, combinedQueryString, secondResultSet.size()); + state.getState().getLocalState() + .log(String.format("%s" + System.lineSeparator() + "%s", firstQueryString, secondQueryString)); + String assertionMessage = String.format( + "The size of the result sets mismatch (%d and %d)!" + System.lineSeparator() + + "First query: \"%s\", whose cardinality is: %d" + System.lineSeparator() + + "Second query:\"%s\", whose cardinality is: %d", + resultSet.size(), secondResultSet.size(), originalQueryString, resultSet.size(), + combinedQueryString, secondResultSet.size()); throw new AssertionError(assertionMessage); } Set firstHashSet = new HashSet<>(resultSet); Set secondHashSet = new HashSet<>(secondResultSet); - if (!firstHashSet.equals(secondHashSet)) { + boolean validateResultSizeOnly = state.getOptions().validateResultSizeOnly(); + if (!validateResultSizeOnly && !firstHashSet.equals(secondHashSet)) { Set firstResultSetMisses = new HashSet<>(firstHashSet); firstResultSetMisses.removeAll(secondHashSet); Set secondResultSetMisses = new HashSet<>(secondHashSet); secondResultSetMisses.removeAll(firstHashSet); - String queryFormatString = "-- %s;\n-- misses: %s"; + + String queryFormatString = "-- Query: \"%s\"; It misses: \"%s\""; String firstQueryString = String.format(queryFormatString, originalQueryString, firstResultSetMisses); - String secondQueryString = String.format(queryFormatString, - combinedString.stream().collect(Collectors.joining(";")), secondResultSetMisses); + String secondQueryString = String.format(queryFormatString, String.join(";", combinedString), + secondResultSetMisses); // update the SELECT queries to be logged at the bottom of the error log file - state.getState().getLocalState().log(String.format("%s\n%s", firstQueryString, secondQueryString)); - String assertionMessage = String.format("the content of the result sets mismatch!\n%s\n%s", - firstQueryString, secondQueryString); + state.getState().getLocalState() + .log(String.format("%s" + System.lineSeparator() + "%s", firstQueryString, secondQueryString)); + String assertionMessage = String.format("The content of the result sets mismatch!" + System.lineSeparator() + + "First query : \"%s\"" + System.lineSeparator() + "Second query: \"%s\"", originalQueryString, + secondQueryString); throw new AssertionError(assertionMessage); } } @@ -166,4 +178,20 @@ public static List getCombinedResultSetNoDuplicates(String firstQueryStr return secondResultSet; } + public static String canonicalizeResultValue(String value) { + if (value == null) { + return value; + } + + switch (value) { + case "-0.0": + return "0.0"; + case "-0": + return "0"; + default: + } + + return value; + } + } diff --git a/src/sqlancer/DatabaseProvider.java b/src/sqlancer/DatabaseProvider.java index 72acb30d8..d169324fa 100644 --- a/src/sqlancer/DatabaseProvider.java +++ b/src/sqlancer/DatabaseProvider.java @@ -24,11 +24,25 @@ public interface DatabaseProvider, O extends DBMS * @param globalState * the state created and is valid for this method call. * + * @return Reproducer if a bug is found and a reproducer is available. + * * @throws Exception * if creating the database fails. * */ - void generateAndTestDatabase(G globalState) throws Exception; + Reproducer generateAndTestDatabase(G globalState) throws Exception; + + /** + * The experimental feature: Query Plan Guidance. + * + * @param globalState + * the state created and is valid for this method call. + * + * @throws Exception + * if testing fails. + * + */ + void generateAndTestDatabaseWithQueryPlanGuidance(G globalState) throws Exception; C createDatabase(G globalState) throws Exception; diff --git a/src/sqlancer/GlobalState.java b/src/sqlancer/GlobalState.java index 64b5c731e..2b93012c2 100644 --- a/src/sqlancer/GlobalState.java +++ b/src/sqlancer/GlobalState.java @@ -131,7 +131,7 @@ public S getSchema() { try { updateSchema(); } catch (Exception e) { - throw new AssertionError(); + throw new AssertionError(e.getMessage()); } } return schema; diff --git a/src/sqlancer/Main.java b/src/sqlancer/Main.java index a4404a6fa..faf35e3c9 100644 --- a/src/sqlancer/Main.java +++ b/src/sqlancer/Main.java @@ -5,6 +5,7 @@ import java.io.IOException; import java.io.Writer; import java.nio.file.Files; +import java.nio.file.Path; import java.text.DateFormat; import java.text.SimpleDateFormat; import java.util.ArrayList; @@ -23,9 +24,31 @@ import com.beust.jcommander.JCommander; import com.beust.jcommander.JCommander.Builder; +import sqlancer.citus.CitusProvider; +import sqlancer.clickhouse.ClickHouseProvider; +import sqlancer.cnosdb.CnosDBProvider; +import sqlancer.cockroachdb.CockroachDBProvider; import sqlancer.common.log.Loggable; import sqlancer.common.query.Query; import sqlancer.common.query.SQLancerResultSet; +import sqlancer.databend.DatabendProvider; +import sqlancer.doris.DorisProvider; +import sqlancer.duckdb.DuckDBProvider; +import sqlancer.h2.H2Provider; +import sqlancer.hive.HiveProvider; +import sqlancer.hsqldb.HSQLDBProvider; +import sqlancer.mariadb.MariaDBProvider; +import sqlancer.materialize.MaterializeProvider; +import sqlancer.mysql.MySQLProvider; +import sqlancer.oceanbase.OceanBaseProvider; +import sqlancer.postgres.PostgresProvider; +import sqlancer.presto.PrestoProvider; +import sqlancer.questdb.QuestDBProvider; +import sqlancer.spark.SparkProvider; +import sqlancer.sqlite3.SQLite3Provider; +import sqlancer.tidb.TiDBProvider; +import sqlancer.yugabyte.ycql.YCQLProvider; +import sqlancer.yugabyte.ysql.YSQLProvider; public final class Main { @@ -38,7 +61,7 @@ public final class Main { static boolean progressMonitorStarted; static { - System.setProperty(org.slf4j.impl.SimpleLogger.DEFAULT_LOG_LEVEL_KEY, "ERROR"); + System.setProperty(org.slf4j.simple.SimpleLogger.DEFAULT_LOG_LEVEL_KEY, "ERROR"); if (!LOG_DIRECTORY.exists()) { LOG_DIRECTORY.mkdir(); } @@ -51,10 +74,19 @@ public static final class StateLogger { private final File loggerFile; private File curFile; + private File queryPlanFile; + private File reduceFile; private FileWriter logFileWriter; public FileWriter currentFileWriter; + private FileWriter queryPlanFileWriter; + private FileWriter reduceFileWriter; + private Path reproduceFilePath; + private static final List INITIALIZED_PROVIDER_NAMES = new ArrayList<>(); private final boolean logEachSelect; + private final boolean logQueryPlan; + + private final boolean useReducer; private final DatabaseProvider databaseProvider; private static final class AlsoWriteToConsoleFileWriter extends FileWriter { @@ -87,6 +119,25 @@ public StateLogger(String databaseName, DatabaseProvider provider, Main if (logEachSelect) { curFile = new File(dir, databaseName + "-cur.log"); } + logQueryPlan = options.logQueryPlan(); + if (logQueryPlan) { + queryPlanFile = new File(dir, databaseName + "-plan.log"); + } + this.useReducer = options.useReducer(); + if (useReducer) { + File reduceFileDir = new File(dir, "reduce"); + if (!reduceFileDir.exists()) { + reduceFileDir.mkdir(); + } + this.reduceFile = new File(reduceFileDir, databaseName + "-reduce.log"); + } + if (options.serializeReproduceState()) { + File reproduceFileDir = new File(dir, "reproduce"); + if (!reproduceFileDir.exists()) { + reproduceFileDir.mkdir(); + } + reproduceFilePath = new File(reproduceFileDir, databaseName + ".ser").toPath(); + } this.databaseProvider = provider; } @@ -138,6 +189,34 @@ public FileWriter getCurrentFileWriter() { return currentFileWriter; } + public FileWriter getQueryPlanFileWriter() { + if (!logQueryPlan) { + throw new UnsupportedOperationException(); + } + if (queryPlanFileWriter == null) { + try { + queryPlanFileWriter = new FileWriter(queryPlanFile, true); + } catch (IOException e) { + throw new AssertionError(e); + } + } + return queryPlanFileWriter; + } + + public FileWriter getReduceFileWriter() { + if (!useReducer) { + throw new UnsupportedOperationException(); + } + if (reduceFileWriter == null) { + try { + reduceFileWriter = new FileWriter(reduceFile, false); + } catch (IOException e) { + throw new AssertionError(e); + } + } + return reduceFileWriter; + } + public void writeCurrent(StateToReproduce state) { if (!logEachSelect) { throw new UnsupportedOperationException(); @@ -172,6 +251,61 @@ private void write(Loggable loggable) { } } + public void writeQueryPlan(String queryPlan) { + if (!logQueryPlan) { + throw new UnsupportedOperationException(); + } + try { + getQueryPlanFileWriter().append(removeNamesFromQueryPlans(queryPlan)); + queryPlanFileWriter.flush(); + } catch (IOException e) { + throw new AssertionError(); + } + } + + public void logReducer(String reducerLog) { + FileWriter reduceFileWriter = getReduceFileWriter(); + + StringBuilder sb = new StringBuilder(); + sb.append("[reducer log] "); + sb.append(reducerLog); + try { + reduceFileWriter.write(sb.toString()); + } catch (IOException e) { + throw new AssertionError(e); + } finally { + try { + reduceFileWriter.flush(); + } catch (IOException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + } + } + + public void logReduced(StateToReproduce state) { + FileWriter reduceFileWriter = getReduceFileWriter(); + + StringBuilder sb = new StringBuilder(); + for (Query s : state.getStatements()) { + sb.append(databaseProvider.getLoggableFactory().createLoggable(s.getLogString()).getLogString()); + } + try { + reduceFileWriter.write(sb.toString()); + + } catch (IOException e) { + throw new AssertionError(e); + } finally { + try { + reduceFileWriter.flush(); + } catch (IOException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + } + + } + public void logException(Throwable reduce, StateToReproduce state) { Loggable stackTrace = getStackTrace(reduce); FileWriter logFileWriter2 = getLogFileWriter(); @@ -184,7 +318,6 @@ public void logException(Throwable reduce, StateToReproduce state) { try { logFileWriter2.flush(); } catch (IOException e) { - // TODO Auto-generated catch block e.printStackTrace(); } } @@ -201,8 +334,7 @@ private void printState(FileWriter writer, StateToReproduce state) { .getInfo(state.getDatabaseName(), state.getDatabaseVersion(), state.getSeedValue()).getLogString()); for (Query s : state.getStatements()) { - sb.append(s.getLogString()); - sb.append('\n'); + sb.append(databaseProvider.getLoggableFactory().createLoggable(s.getLogString()).getLogString()); } try { writer.write(sb.toString()); @@ -211,6 +343,17 @@ private void printState(FileWriter writer, StateToReproduce state) { } } + private String removeNamesFromQueryPlans(String queryPlan) { + String result = queryPlan; + result = result.replaceAll("t[0-9]+", "t0"); // Avoid duplicate tables + result = result.replaceAll("v[0-9]+", "v0"); // Avoid duplicate views + result = result.replaceAll("i[0-9]+", "i0"); // Avoid duplicate indexes + return result + "\n"; + } + + public Path getReproduceFilePath() { + return reproduceFilePath; + } } public static class QueryManager { @@ -222,10 +365,12 @@ public static class QueryManager { } public boolean execute(Query q, String... fills) throws Exception { - globalState.getState().logStatement(q); boolean success; success = q.execute(globalState, fills); Main.nrSuccessfulActions.addAndGet(1); + if (globalState.getOptions().loggerPrintFailed() || success) { + globalState.getState().logStatement(q); + } return success; } @@ -241,6 +386,10 @@ public void incrementSelectQueryCount() { Main.nrQueries.addAndGet(1); } + public Long getSelectQueryCount() { + return Main.nrQueries.get(); + } + public void incrementCreateDatabase() { Main.nrDatabases.addAndGet(1); } @@ -312,13 +461,57 @@ public void run() throws Exception { if (options.logEachSelect()) { logger.writeCurrent(state.getState()); } - provider.generateAndTestDatabase(state); + Reproducer reproducer = null; + if (options.enableQPG()) { + provider.generateAndTestDatabaseWithQueryPlanGuidance(state); + } else { + reproducer = provider.generateAndTestDatabase(state); + } try { logger.getCurrentFileWriter().close(); logger.currentFileWriter = null; } catch (IOException e) { throw new AssertionError(e); } + + if (options.serializeReproduceState() && reproducer != null) { + stateToRepro.serialize(logger.getReproduceFilePath()); + } + if (options.reduceAST() && !options.useReducer()) { + throw new AssertionError("To reduce AST, use-reducer option must be enabled first"); + } + if (options.useReducer()) { + if (reproducer == null) { + logger.getReduceFileWriter().write("current oracle does not support experimental reducer."); + throw new IgnoreMeException(); + } + G newGlobalState = createGlobalState(); + newGlobalState.setState(stateToRepro); + newGlobalState.setRandomly(r); + newGlobalState.setDatabaseName(databaseName); + newGlobalState.setMainOptions(options); + newGlobalState.setDbmsSpecificOptions(command); + QueryManager newManager = new QueryManager<>(newGlobalState); + newGlobalState.setStateLogger(new StateLogger(databaseName, provider, options)); + newGlobalState.setManager(newManager); + + Reducer reducer = new StatementReducer<>(provider); + reducer.reduce(state, reproducer, newGlobalState); + + if (options.reduceAST()) { + Reducer astBasedReducer = new ASTBasedReducer<>(provider); + astBasedReducer.reduce(state, reproducer, newGlobalState); + } + + try { + logger.getReduceFileWriter().close(); + logger.reduceFileWriter = null; + } catch (IOException e) { + throw new AssertionError(e); + } + + throw new AssertionError("Found a potential bug, please check reducer log for detail."); + } } } @@ -419,7 +612,7 @@ public void run() { System.out.println( formatInteger(nrSuccessfulActions.get()) + " successfully-executed statements"); System.out.println( - formatInteger(nrUnsuccessfulActions.get()) + " unsuccessfuly-executed statements"); + formatInteger(nrUnsuccessfulActions.get()) + " unsuccessfully-executed statements"); } private String formatInteger(long intValue) { @@ -498,6 +691,10 @@ private boolean run(MainOptions options, ExecutorService execService, executor.getStateToReproduce().exception = reduce.getMessage(); executor.getLogger().logFileWriter = null; executor.getLogger().logException(reduce, executor.getStateToReproduce()); + if (options.serializeReproduceState()) { + executor.getStateToReproduce().logStatement(reduce.getMessage()); // add the error statement + executor.getStateToReproduce().serialize(executor.getLogger().getReproduceFilePath()); + } return false; } finally { try { @@ -542,9 +739,40 @@ private boolean run(MainOptions options, ExecutorService execService, for (DatabaseProvider provider : loader) { providers.add(provider); } + checkForIssue799(providers); return providers; } + // see https://github.com/sqlancer/sqlancer/issues/799 + private static void checkForIssue799(List> providers) { + if (providers.isEmpty()) { + System.err.println( + "No DBMS implementations (i.e., instantiations of the DatabaseProvider class) were found. You likely ran into an issue described in https://github.com/sqlancer/sqlancer/issues/799. As a workaround, I now statically load all supported providers as of June 7, 2023."); + providers.add(new CitusProvider()); + providers.add(new ClickHouseProvider()); + providers.add(new CnosDBProvider()); + providers.add(new CockroachDBProvider()); + providers.add(new DatabendProvider()); + providers.add(new DorisProvider()); + providers.add(new DuckDBProvider()); + providers.add(new H2Provider()); + providers.add(new HiveProvider()); + providers.add(new SparkProvider()); + providers.add(new HSQLDBProvider()); + providers.add(new MariaDBProvider()); + providers.add(new MaterializeProvider()); + providers.add(new MySQLProvider()); + providers.add(new OceanBaseProvider()); + providers.add(new PrestoProvider()); + providers.add(new PostgresProvider()); + providers.add(new QuestDBProvider()); + providers.add(new SQLite3Provider()); + providers.add(new TiDBProvider()); + providers.add(new YCQLProvider()); + providers.add(new YSQLProvider()); + } + } + private static synchronized void startProgressMonitor() { if (progressMonitorStarted) { /* diff --git a/src/sqlancer/MainOptions.java b/src/sqlancer/MainOptions.java index ff588de5f..a5142fcf0 100644 --- a/src/sqlancer/MainOptions.java +++ b/src/sqlancer/MainOptions.java @@ -1,5 +1,7 @@ package sqlancer; +import java.util.Objects; + import com.beust.jcommander.Parameter; import com.beust.jcommander.Parameters; @@ -8,6 +10,8 @@ @Parameters(separators = "=", commandDescription = "Options applicable to all DBMS") public class MainOptions { public static final int NO_SET_PORT = -1; + public static final int NO_REDUCE_LIMIT = -1; + public static final MainOptions DEFAULT_OPTIONS = new MainOptions(); @Parameter(names = { "--help", "-h" }, description = "Lists all supported options and commands", help = true) private boolean help; // NOPMD @@ -44,6 +48,24 @@ public class MainOptions { @Parameter(names = "--log-execution-time", description = "Logs the execution time of each statement (requires --log-each-select to be enabled)", arity = 1) private boolean logExecutionTime = true; // NOPMD + @Parameter(names = "--print-failed", description = "Logs failed insert, create and other statements without results", arity = 1) + private boolean loggerPrintFailed = true; // NOPMD + + @Parameter(names = "--qpg-enable", description = "Enable the experimental feature Query Plan Guidance (QPG)", arity = 1) + private boolean enableQPG; + + @Parameter(names = "--qpg-log-query-plan", description = "Logs the query plans of each query (requires --qpg-enable)", arity = 1) + private boolean logQueryPlan; + + @Parameter(names = "--qpg-max-interval", description = "The maximum number of iterations to mutate tables if no new query plans (requires --qpg-enable)") + private static int qpgMaxInterval = 1000; + + @Parameter(names = "--qpg-reward-weight", description = "The weight (0-1) of last reward when updating weighted average reward. A higher value denotes average reward is more affected by the last reward (requires --qpg-enable)") + private static double qpgk = 0.25; + + @Parameter(names = "--qpg-selection-probability", description = "The probability (0-1) of the random selection of mutators. A higher value (>0.5) favors exploration over exploitation. (requires --qpg-enable)") + private static double qpgProbability = 0.7; + @Parameter(names = "--username", description = "The user name used to log into the DBMS") private String userName = "sqlancer"; // NOPMD @@ -101,6 +123,33 @@ public class MainOptions { @Parameter(names = "--database-prefix", description = "The prefix used for each database created") private String databasePrefix = "database"; // NOPMD + @Parameter(names = "--serialize-reproduce-state", description = "Serialize the state to reproduce") + private boolean serializeReproduceState = false; // NOPMD + + @Parameter(names = "--use-reducer", description = "EXPERIMENTAL Attempt to reduce queries using a simple reducer") + private boolean useReducer = false; // NOPMD + + @Parameter(names = "--reduce-ast", description = "EXPERIMENTAL perform AST reduction after statement reduction") + private boolean reduceAST = false; // NOPMD + + @Parameter(names = "--statement-reducer-max-steps", description = "EXPERIMENTAL Maximum steps the statement reducer will do") + private long maxStatementReduceSteps = NO_REDUCE_LIMIT; // NOPMD + + @Parameter(names = "--statement-reducer-max-time", description = "EXPERIMENTAL Maximum time duration (secs) the AST-based reducer will do") + private long maxASTReduceTime = NO_REDUCE_LIMIT; // NOPMD + + @Parameter(names = "--ast-reducer-max-steps", description = "EXPERIMENTAL Maximum steps the AST-based reducer will do") + private long maxASTReduceSteps = NO_REDUCE_LIMIT; // NOPMD + + @Parameter(names = "--ast-reducer-max-time", description = "EXPERIMENTAL Maximum time duration (secs) the statement reducer will do") + private long maxStatementReduceTime = NO_REDUCE_LIMIT; // NOPMD + + @Parameter(names = "--validate-result-size-only", description = "Should validate result size only and skip comparing content of the result set ", arity = 1) + private boolean validateResultSizeOnly = false; // NOPMD + + @Parameter(names = "--canonicalize-sql-strings", description = "Should canonicalize query string (add ';' at the end", arity = 1) + private boolean canonicalizeSqlString = true; // NOPMD + public int getMaxExpressionDepth() { return maxExpressionDepth; } @@ -138,6 +187,30 @@ public boolean logExecutionTime() { return logExecutionTime; } + public boolean loggerPrintFailed() { + return loggerPrintFailed; + } + + public boolean logQueryPlan() { + return logQueryPlan; + } + + public boolean enableQPG() { + return enableQPG; + } + + public int getQPGMaxMutationInterval() { + return qpgMaxInterval; + } + + public double getQPGk() { + return qpgk; + } + + public double getQPGProbability() { + return qpgProbability; + } + public int getNrQueries() { return nrQueries; } @@ -218,6 +291,14 @@ public boolean isHelp() { return help; } + public boolean isDefaultPassword() { + return Objects.equals(password, DEFAULT_OPTIONS.password); + } + + public boolean isDefaultUsername() { + return Objects.equals(userName, DEFAULT_OPTIONS.userName); + } + public String getDatabasePrefix() { return databasePrefix; } @@ -226,4 +307,40 @@ public boolean performConnectionTest() { return useConnectionTest; } + public boolean serializeReproduceState() { + return serializeReproduceState; + } + + public boolean useReducer() { + return useReducer; + } + + public boolean reduceAST() { + return reduceAST; + } + + public long getMaxStatementReduceSteps() { + return maxStatementReduceSteps; + } + + public long getMaxStatementReduceTime() { + return maxStatementReduceTime; + } + + public long getMaxASTReduceSteps() { + return maxASTReduceSteps; + } + + public long getMaxASTReduceTime() { + return maxASTReduceTime; + } + + public boolean validateResultSizeOnly() { + return validateResultSizeOnly; + } + + public boolean canonicalizeSqlString() { + return canonicalizeSqlString; + } + } diff --git a/src/sqlancer/OracleFactory.java b/src/sqlancer/OracleFactory.java index 897293c6b..9d6e1704b 100644 --- a/src/sqlancer/OracleFactory.java +++ b/src/sqlancer/OracleFactory.java @@ -4,7 +4,7 @@ public interface OracleFactory> { - TestOracle create(G globalState) throws Exception; + TestOracle create(G globalState) throws Exception; /** * Indicates whether the test oracle requires that all tables (including views) contain at least one row. diff --git a/src/sqlancer/ProviderAdapter.java b/src/sqlancer/ProviderAdapter.java index a16f8388c..346567300 100644 --- a/src/sqlancer/ProviderAdapter.java +++ b/src/sqlancer/ProviderAdapter.java @@ -1,9 +1,14 @@ package sqlancer; +import java.sql.SQLException; +import java.util.HashMap; +import java.util.Iterator; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import sqlancer.StateToReproduce.OracleRunReproductionState; +import sqlancer.common.DBMSCommon; import sqlancer.common.oracle.CompositeTestOracle; import sqlancer.common.oracle.TestOracle; import sqlancer.common.schema.AbstractSchema; @@ -14,7 +19,14 @@ public abstract class ProviderAdapter globalClass; private final Class optionClass; - public ProviderAdapter(Class globalClass, Class optionClass) { + // Variables for QPG + Map queryPlanPool = new HashMap<>(); + static double[] weightedAverageReward; // static variable for sharing across all threads + int currentSelectRewards; + int currentSelectCounts; + int currentMutationOperator = -1; + + protected ProviderAdapter(Class globalClass, Class optionClass) { this.globalClass = globalClass; this.optionClass = optionClass; } @@ -35,47 +47,56 @@ public Class getOptionClass() { } @Override - public void generateAndTestDatabase(G globalState) throws Exception { + public Reproducer generateAndTestDatabase(G globalState) throws Exception { try { generateDatabase(globalState); checkViewsAreValid(globalState); globalState.getManager().incrementCreateDatabase(); - TestOracle oracle = getTestOracle(globalState); + TestOracle oracle = getTestOracle(globalState); for (int i = 0; i < globalState.getOptions().getNrQueries(); i++) { try (OracleRunReproductionState localState = globalState.getState().createLocalState()) { assert localState != null; try { oracle.check(); globalState.getManager().incrementSelectQueryCount(); - } catch (IgnoreMeException e) { - + } catch (IgnoreMeException ignored) { + } catch (AssertionError e) { + Reproducer reproducer = oracle.getLastReproducer(); + if (reproducer != null) { + return reproducer; + } + throw e; } - assert localState != null; localState.executedWithoutError(); } } } finally { globalState.getConnection().close(); } + return null; } - protected abstract void checkViewsAreValid(G globalState); + protected abstract void checkViewsAreValid(G globalState) throws SQLException; - protected TestOracle getTestOracle(G globalState) throws Exception { + protected TestOracle getTestOracle(G globalState) throws Exception { List> testOracleFactory = globalState.getDbmsSpecificOptions() .getTestOracleFactory(); boolean testOracleRequiresMoreThanZeroRows = testOracleFactory.stream() - .anyMatch(p -> p.requiresAllTablesToContainRows()); + .anyMatch(OracleFactory::requiresAllTablesToContainRows); boolean userRequiresMoreThanZeroRows = globalState.getOptions().testOnlyWithMoreThanZeroRows(); boolean checkZeroRows = testOracleRequiresMoreThanZeroRows || userRequiresMoreThanZeroRows; if (checkZeroRows && globalState.getSchema().containsTableWithZeroRows(globalState)) { - throw new IgnoreMeException(); + if (globalState.getOptions().enableQPG()) { + addRowsToAllTables(globalState); + } else { + throw new IgnoreMeException(); + } } if (testOracleFactory.size() == 1) { return testOracleFactory.get(0).create(globalState); } else { - return new CompositeTestOracle(testOracleFactory.stream().map(o -> { + return new CompositeTestOracle<>(testOracleFactory.stream().map(o -> { try { return o.create(globalState); } catch (Exception e1) { @@ -87,4 +108,151 @@ protected TestOracle getTestOracle(G globalState) throws Exception { public abstract void generateDatabase(G globalState) throws Exception; + // QPG: entry function + @Override + public void generateAndTestDatabaseWithQueryPlanGuidance(G globalState) throws Exception { + if (weightedAverageReward == null) { + weightedAverageReward = initializeWeightedAverageReward(); // Same length as the list of mutators + } + try { + generateDatabase(globalState); + checkViewsAreValid(globalState); + globalState.getManager().incrementCreateDatabase(); + + Long executedQueryCount = 0L; + while (executedQueryCount < globalState.getOptions().getNrQueries()) { + int numOfNoNewQueryPlans = 0; + TestOracle oracle = getTestOracle(globalState); + while (executedQueryCount < globalState.getOptions().getNrQueries()) { + try (OracleRunReproductionState localState = globalState.getState().createLocalState()) { + assert localState != null; + try { + oracle.check(); + String query = oracle.getLastQueryString(); + executedQueryCount += 1; + if (addQueryPlan(query, globalState)) { + numOfNoNewQueryPlans = 0; + } else { + numOfNoNewQueryPlans++; + } + globalState.getManager().incrementSelectQueryCount(); + } catch (IgnoreMeException e) { + + } + localState.executedWithoutError(); + } + // exit loop to mutate tables if no new query plans have been found after a while + if (numOfNoNewQueryPlans > globalState.getOptions().getQPGMaxMutationInterval()) { + mutateTables(globalState); + break; + } + } + } + } finally { + globalState.getConnection().close(); + } + } + + // QPG: mutate tables for a new database state + private synchronized boolean mutateTables(G globalState) throws Exception { + // Update rewards based on a set of newly generated queries in last iteration + if (currentMutationOperator != -1) { + weightedAverageReward[currentMutationOperator] += ((double) currentSelectRewards + / (double) currentSelectCounts) * globalState.getOptions().getQPGk(); + } + currentMutationOperator = -1; + + // Choose mutator based on the rewards + int selectedActionIndex = 0; + if (Randomly.getPercentage() < globalState.getOptions().getQPGProbability()) { + selectedActionIndex = globalState.getRandomly().getInteger(0, weightedAverageReward.length); + } else { + selectedActionIndex = DBMSCommon.getMaxIndexInDoubleArray(weightedAverageReward); + } + int reward = 0; + + try { + executeMutator(selectedActionIndex, globalState); + checkViewsAreValid(globalState); // Remove the invalid views + reward = checkQueryPlan(globalState); + } catch (IgnoreMeException | AssertionError e) { + } finally { + // Update rewards based on existing queries associated with the query plan pool + updateReward(selectedActionIndex, (double) reward / (double) queryPlanPool.size(), globalState); + currentMutationOperator = selectedActionIndex; + } + + // Clear the variables for storing the rewards of the action on a set of newly generated queries + currentSelectRewards = 0; + currentSelectCounts = 0; + return true; + } + + // QPG: add a query plan to the query plan pool and return true if the query plan is new + private boolean addQueryPlan(String selectStr, G globalState) throws Exception { + String queryPlan = getQueryPlan(selectStr, globalState); + + if (globalState.getOptions().logQueryPlan()) { + globalState.getLogger().writeQueryPlan(queryPlan); + } + + currentSelectCounts += 1; + if (queryPlanPool.containsKey(queryPlan)) { + return false; + } else { + queryPlanPool.put(queryPlan, selectStr); + currentSelectRewards += 1; + return true; + } + } + + // Obtain the reward of the current action based on the queries associated with the query plan pool + private int checkQueryPlan(G globalState) throws Exception { + int newQueryPlanFound = 0; + HashMap modifiedQueryPlan = new HashMap<>(); + for (Iterator> it = queryPlanPool.entrySet().iterator(); it.hasNext();) { + Map.Entry item = it.next(); + String queryPlan = item.getKey(); + String selectStr = item.getValue(); + String newQueryPlan = getQueryPlan(selectStr, globalState); + if (newQueryPlan.isEmpty()) { // Invalid query + it.remove(); + } else if (!queryPlan.equals(newQueryPlan)) { // A query plan has been changed + it.remove(); + modifiedQueryPlan.put(newQueryPlan, selectStr); + if (!queryPlanPool.containsKey(newQueryPlan)) { // A new query plan is found + newQueryPlanFound++; + } + } + } + queryPlanPool.putAll(modifiedQueryPlan); + return newQueryPlanFound; + } + + // QPG: update the reward of current action + private void updateReward(int actionIndex, double reward, G globalState) { + weightedAverageReward[actionIndex] += (reward - weightedAverageReward[actionIndex]) + * globalState.getOptions().getQPGk(); + } + + // QPG: initialize the weighted average reward of all mutation operators (required implementation in specific DBMS) + protected double[] initializeWeightedAverageReward() { + throw new UnsupportedOperationException(); + } + + // QPG: obtain the query plan of a query (required implementation in specific DBMS) + protected String getQueryPlan(String selectStr, G globalState) throws Exception { + throw new UnsupportedOperationException(); + } + + // QPG: execute a mutation operator (required implementation in specific DBMS) + protected void executeMutator(int index, G globalState) throws Exception { + throw new UnsupportedOperationException(); + } + + // QPG: add rows to all tables (required implementation in specific DBMS when enabling PQS oracle for QPG) + protected boolean addRowsToAllTables(G globalState) throws Exception { + throw new UnsupportedOperationException(); + } + } diff --git a/src/sqlancer/Randomly.java b/src/sqlancer/Randomly.java index fa317f1f0..8494c189a 100644 --- a/src/sqlancer/Randomly.java +++ b/src/sqlancer/Randomly.java @@ -1,8 +1,10 @@ package sqlancer; import java.math.BigDecimal; +import java.math.BigInteger; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Random; import java.util.function.Supplier; @@ -15,6 +17,7 @@ public final class Randomly { private static int cacheSize = 100; private final List cachedLongs = new ArrayList<>(); + private final List cachedIntegers = new ArrayList<>(); private final List cachedStrings = new ArrayList<>(); private final List cachedDoubles = new ArrayList<>(); private final List cachedBytes = new ArrayList<>(); @@ -29,6 +32,12 @@ private void addToCache(long val) { } } + private void addToCache(int val) { + if (useCaching && cachedIntegers.size() < cacheSize && !cachedIntegers.contains(val)) { + cachedIntegers.add(val); + } + } + private void addToCache(double val) { if (useCaching && cachedDoubles.size() < cacheSize && !cachedDoubles.contains(val)) { cachedDoubles.add(val); @@ -49,6 +58,14 @@ private Long getFromLongCache() { } } + private Integer getFromIntegerCache() { + if (!useCaching || cachedIntegers.isEmpty()) { + return null; + } else { + return Randomly.fromList(cachedIntegers); + } + } + private Double getFromDoubleCache() { if (!useCaching) { return null; @@ -118,6 +135,12 @@ public static List nonEmptySubset(List columns, int nr) { return extractNrRandomColumns(columns, nr); } + public static List nonEmptySubsetLeast(List columns, int min) { + int nr = getNextInt(min, columns.size() + 1); + assert nr <= columns.size(); + return extractNrRandomColumns(columns, nr); + } + public static List nonEmptySubsetPotentialDuplicates(List columns) { List arr = new ArrayList<>(); for (int i = 0; i < Randomly.smallNumber() + 1; i++) { @@ -133,17 +156,12 @@ public static List subset(List columns) { public static List subset(int nr, @SuppressWarnings("unchecked") T... values) { List list = new ArrayList<>(); - for (T val : values) { - list.add(val); - } + Collections.addAll(list, values); return extractNrRandomColumns(list, nr); } public static List subset(@SuppressWarnings("unchecked") T... values) { - List list = new ArrayList<>(); - for (T val : values) { - list.add(val); - } + List list = new ArrayList<>(Arrays.asList(values)); return subset(list); } @@ -204,7 +222,6 @@ public String getString(Randomly r) { }, ALPHANUMERIC { - @Override public String getString(Randomly r) { return getStringOfAlphabet(r, ALPHANUMERIC_ALPHABET); @@ -213,7 +230,6 @@ public String getString(Randomly r) { }, ALPHANUMERIC_SPECIALCHAR { - @Override public String getString(Randomly r) { return getStringOfAlphabet(r, ALPHANUMERIC_SPECIALCHAR_ALPHABET); @@ -350,7 +366,6 @@ public long getNonZeroInteger() { do { value = getInteger(); } while (value == 0); - assert value != 0; addToCache(value); return value; } @@ -373,6 +388,24 @@ public long getPositiveInteger() { return value; } + public int getPositiveIntegerInt() { + if (cacheProbability()) { + Integer value = getFromIntegerCache(); + if (value != null && value >= 0) { + return value; + } + } + int value; + if (smallBiasProbability()) { + value = Randomly.fromOptions(0, Integer.MAX_VALUE, 1); + } else { + value = getNextInt(0, Integer.MAX_VALUE); + } + addToCache(value); + assert value >= 0; + return value; + } + public double getFiniteDouble() { while (true) { double val = getDouble(); @@ -424,8 +457,19 @@ public long getLong(long left, long right) { return getNextLong(left, right); } + public BigInteger getBigInteger(BigInteger left, BigInteger right) { + if (left.equals(right)) { + return left; + } + BigInteger result = new BigInteger(String.valueOf(getInteger(left.intValue(), right.intValue()))); + if (result.compareTo(left) < 0 && result.compareTo(right) > 0) { + throw new IgnoreMeException(); + } + return result; + } + public BigDecimal getRandomBigDecimal() { - return new BigDecimal(getThreadRandom().get().nextDouble()); + return BigDecimal.valueOf(getThreadRandom().get().nextDouble()); } public long getPositiveIntegerNotNull() { @@ -494,7 +538,7 @@ private static long getNextLong(long lower, long upper) { if (lower == upper) { return lower; } - return (long) (getThreadRandom().get().longs(lower, upper).findFirst().getAsLong()); + return getThreadRandom().get().longs(lower, upper).findFirst().getAsLong(); } private static int getNextInt(int lower, int upper) { diff --git a/src/sqlancer/Reducer.java b/src/sqlancer/Reducer.java new file mode 100644 index 000000000..0e6589262 --- /dev/null +++ b/src/sqlancer/Reducer.java @@ -0,0 +1,7 @@ +package sqlancer; + +public interface Reducer> { + + void reduce(G state, Reproducer reproducer, G newGlobalState) throws Exception; + +} diff --git a/src/sqlancer/Reproducer.java b/src/sqlancer/Reproducer.java new file mode 100644 index 000000000..ef64bd0fe --- /dev/null +++ b/src/sqlancer/Reproducer.java @@ -0,0 +1,5 @@ +package sqlancer; + +public interface Reproducer> { + boolean bugStillTriggers(G globalState); +} diff --git a/src/sqlancer/SQLProviderAdapter.java b/src/sqlancer/SQLProviderAdapter.java index 8de7523a3..efb4fab67 100644 --- a/src/sqlancer/SQLProviderAdapter.java +++ b/src/sqlancer/SQLProviderAdapter.java @@ -10,7 +10,7 @@ public abstract class SQLProviderAdapter>, O extends DBMSSpecificOptions>> extends ProviderAdapter { - public SQLProviderAdapter(Class globalClass, Class optionClass) { + protected SQLProviderAdapter(Class globalClass, Class optionClass) { super(globalClass, optionClass); } diff --git a/src/sqlancer/StandaloneReducer.java b/src/sqlancer/StandaloneReducer.java new file mode 100644 index 000000000..813160060 --- /dev/null +++ b/src/sqlancer/StandaloneReducer.java @@ -0,0 +1,140 @@ +package sqlancer; + +import java.io.FileWriter; +import java.io.PrintWriter; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; + +import sqlancer.common.query.Query; + +/** + * A standalone tool to reduce bug-triggering SQL statements using the delta debugging algorithm. + */ +public class StandaloneReducer { + private int partitionNum = 2; + private final StateToReproduce originalState; + private final DatabaseProvider databaseProvider; + private final Path outputPath; + + public StandaloneReducer(Path inputPath, Path outputPath) throws Exception { + this.originalState = StateToReproduce.deserialize(inputPath); + this.databaseProvider = originalState.getDatabaseProvider(); + if (this.databaseProvider == null) { + throw new IllegalStateException("Failed to get database provider from .ser file"); + } + this.outputPath = outputPath != null ? outputPath + : Paths.get(inputPath.toString().replaceAll("\\.ser$", ".sql")); + } + + /** + * Performs the main reduction algorithm using partition-based delta debugging. + * + * @return List of reduced SQL statements that still trigger bugs. + */ + public List> reduce() throws Exception { + List> queries = new ArrayList<>(originalState.getStatements()); + if (queries.size() <= 1) { + return queries; + } + + partitionNum = 2; + while (queries.size() >= 2) { + boolean changedInThisPass = false; + List> result = tryReduction(queries); + + if (result.size() < queries.size()) { + queries = result; + changedInThisPass = true; + } + + if (changedInThisPass) { + partitionNum = 2; + } else { + if (partitionNum >= queries.size()) { + break; + } + partitionNum = Math.min(partitionNum * 2, queries.size()); + } + } + + try (PrintWriter writer = new PrintWriter(new FileWriter(outputPath.toFile()))) { + for (Query query : queries) { + writer.println(query.getQueryString()); + } + } + System.out.println("Reduction completed successfully! SQL statements written to: " + outputPath.toString()); + System.out.println("Final size: " + queries.size() + " statements (" + + String.format("%.1f", (1.0 - (double) queries.size() / originalState.getStatements().size()) * 100) + + "% reduction)"); + + return queries; + } + + private List> tryReduction(List> queries) throws Exception { + int start = 0; + int subLength = queries.size() / partitionNum; + + while (start < queries.size()) { + List> candidateQueries = new ArrayList<>(queries); + int endPoint = Math.min(start + subLength, candidateQueries.size()); + candidateQueries.subList(start, endPoint).clear(); + + if (testExceptionStillExists(candidateQueries)) { + return candidateQueries; + } + + start += subLength; + } + + return queries; + } + + // Test if bug still exists with reduced query set + @SuppressWarnings("unchecked") + private , O extends DBMSSpecificOptions, C extends SQLancerDBConnection> boolean testExceptionStillExists( + List> queries) { + try { + DatabaseProvider typedProvider = (DatabaseProvider) databaseProvider; + G globalState = typedProvider.getGlobalStateClass().getDeclaredConstructor().newInstance(); + + try (C connection = typedProvider.createDatabase(globalState)) { + globalState.setConnection(connection); + for (Query query : queries) { + try { + Query typedQuery = (Query) query; + typedQuery.execute(globalState); + } catch (Throwable e) { + // Any exception not declared as an expected error by the query indicates that an (unexpected) + // exception still exists + return true; + } + } + // No exception occurred + return false; + } + } catch (Throwable e) { + return true; + } + } + + public static void main(String[] args) { + try { + if (args.length == 0) { + System.err.println( + "Usage: java -cp target/sqlancer-2.0.0.jar sqlancer.StandaloneReducer [output-file]"); + System.exit(1); + } + Path inputPath = Paths.get(args[0]); + Path outputPath = args.length > 1 ? Paths.get(args[1]) : null; + + StandaloneReducer reducer = new StandaloneReducer(inputPath, outputPath); + reducer.reduce(); + } catch (Throwable e) { + System.err.println("ERROR: " + e.getMessage()); + e.printStackTrace(); + System.exit(1); + } + } +} diff --git a/src/sqlancer/StateToReproduce.java b/src/sqlancer/StateToReproduce.java index f6408ce58..e44d0ccf6 100644 --- a/src/sqlancer/StateToReproduce.java +++ b/src/sqlancer/StateToReproduce.java @@ -1,19 +1,26 @@ package sqlancer; import java.io.Closeable; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.nio.file.Files; +import java.nio.file.Path; import java.util.ArrayList; import java.util.Collections; import java.util.List; import sqlancer.common.query.Query; -public class StateToReproduce { +public class StateToReproduce implements Serializable { + private static final long serialVersionUID = 1L; - private final List> statements = new ArrayList<>(); + private List> statements = new ArrayList<>(); private final String databaseName; - private final DatabaseProvider databaseProvider; + private transient DatabaseProvider databaseProvider; public String databaseVersion; @@ -21,7 +28,7 @@ public class StateToReproduce { String exception; - public OracleRunReproductionState localState; + public transient OracleRunReproductionState localState; public StateToReproduce(String databaseName, DatabaseProvider databaseProvider) { this.databaseName = databaseName; @@ -40,6 +47,10 @@ public String getDatabaseVersion() { return databaseVersion; } + public DatabaseProvider getDatabaseProvider() { + return databaseProvider; + } + /** * Logs the statement string without executing the corresponding statement. * @@ -70,6 +81,9 @@ public List> getStatements() { return Collections.unmodifiableList(statements); } + /** + * @deprecated + */ @Deprecated public void commentStatements() { for (int i = 0; i < statements.size(); i++) { @@ -100,7 +114,7 @@ public class OracleRunReproductionState implements Closeable { private final List> statements = new ArrayList<>(); - public boolean success; + private boolean success; public OracleRunReproductionState() { StateToReproduce.this.localState = this; @@ -128,4 +142,47 @@ public OracleRunReproductionState createLocalState() { return new OracleRunReproductionState(); } + public void serialize(Path path) { + try (ObjectOutputStream oos = new ObjectOutputStream(Files.newOutputStream(path))) { + oos.writeObject(this); + } catch (IOException e) { + throw new AssertionError(e); + } + } + + public static StateToReproduce deserialize(Path path) { + try (ObjectInputStream ois = new ObjectInputStream(Files.newInputStream(path))) { + return (StateToReproduce) ois.readObject(); + } catch (IOException | ClassNotFoundException e) { + throw new AssertionError(e); + } + } + + private void writeObject(ObjectOutputStream out) throws IOException { + out.defaultWriteObject(); + + out.writeObject(this.databaseProvider != null ? this.databaseProvider.getDBMSName() : null); + } + + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + in.defaultReadObject(); + String dbmsName = (String) in.readObject(); + + DatabaseProvider provider = null; + if (dbmsName != null) { + List> providers = Main.getDBMSProviders(); + for (DatabaseProvider p : providers) { + if (p.getDBMSName().equals(dbmsName)) { + provider = p; + break; + } + } + } + this.databaseProvider = provider; + } + + public void setStatements(List> statements) { + this.statements = statements; + } + } diff --git a/src/sqlancer/StatementExecutor.java b/src/sqlancer/StatementExecutor.java index 26265d136..4f8f48b8f 100644 --- a/src/sqlancer/StatementExecutor.java +++ b/src/sqlancer/StatementExecutor.java @@ -70,7 +70,7 @@ public void executeStatements() throws Exception { success = globalState.executeStatement(query); } while (nextAction.canBeRetried() && !success && nrTries++ < globalState.getOptions().getNrStatementRetryCount()); - } catch (IgnoreMeException e) { + } catch (IgnoreMeException ignored) { } if (query != null && query.couldAffectSchema()) { diff --git a/src/sqlancer/StatementReducer.java b/src/sqlancer/StatementReducer.java new file mode 100644 index 000000000..e066aca84 --- /dev/null +++ b/src/sqlancer/StatementReducer.java @@ -0,0 +1,146 @@ +package sqlancer; + +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; + +import sqlancer.common.query.Query; + +public class StatementReducer, O extends DBMSSpecificOptions, C extends SQLancerDBConnection> + implements Reducer { + private final DatabaseProvider provider; + private boolean observedChange; + private int partitionNum; + + private long currentReduceSteps; + private long currentReduceTime; + + private long maxReduceSteps; + private long maxReduceTime; + + Instant timeOfReductionBegins; + + public StatementReducer(DatabaseProvider provider) { + this.provider = provider; + } + + private boolean hasNotReachedLimit(long curr, long limit) { + if (limit == MainOptions.NO_REDUCE_LIMIT) { + return true; + } + return curr < limit; + } + + @SuppressWarnings("unchecked") + @Override + public void reduce(G state, Reproducer reproducer, G newGlobalState) throws Exception { + + maxReduceTime = state.getOptions().getMaxStatementReduceTime(); + maxReduceSteps = state.getOptions().getMaxStatementReduceSteps(); + + List> knownToReproduceBugStatements = new ArrayList<>(); + for (Query stat : state.getState().getStatements()) { + knownToReproduceBugStatements.add((Query) stat); + } + + // System.out.println("Starting query:"); + // Main.StateLogger logger = newGlobalState.getLogger(); + // printQueries(knownToReproduceBugStatements); + // System.out.println(); + + if (knownToReproduceBugStatements.size() <= 1) { + return; + } + + timeOfReductionBegins = Instant.now(); + currentReduceSteps = 0; + currentReduceTime = 0; + partitionNum = 2; + + while (knownToReproduceBugStatements.size() >= 2 && hasNotReachedLimit(currentReduceSteps, maxReduceSteps) + && hasNotReachedLimit(currentReduceTime, maxReduceTime)) { + observedChange = false; + + knownToReproduceBugStatements = tryReduction(state, reproducer, newGlobalState, + knownToReproduceBugStatements); + + if (!observedChange) { + if (partitionNum == knownToReproduceBugStatements.size()) { + break; + } + // increase the search granularity + partitionNum = Math.min(partitionNum * 2, knownToReproduceBugStatements.size()); + } + } + + // System.out.println("Reduced query:"); + // printQueries(knownToReproduceBugStatements); + newGlobalState.getState().setStatements(new ArrayList<>(knownToReproduceBugStatements)); + newGlobalState.getLogger().logReduced(newGlobalState.getState()); + + } + + private List> tryReduction(G state, // NOPMD + Reproducer reproducer, G newGlobalState, List> knownToReproduceBugStatements) throws Exception { + + List> statements = knownToReproduceBugStatements; + + int start = 0; + int subLength = statements.size() / partitionNum; + while (start < statements.size()) { + // newStatements = candidate[:start] + candidate[start+subLength:] + // in other word, remove [start, start+subLength) from candidates + try (C con2 = provider.createDatabase(newGlobalState)) { + newGlobalState.setConnection(con2); + List> candidateStatements = new ArrayList<>(statements); + int endPoint = Math.min(start + subLength, candidateStatements.size()); + candidateStatements.subList(start, endPoint).clear(); + newGlobalState.getState().setStatements(new ArrayList<>(candidateStatements)); + + for (Query s : candidateStatements) { + try { + s.execute(newGlobalState); + } catch (Throwable ignoredException) { + // ignore + } + } + try { + if (reproducer.bugStillTriggers(newGlobalState)) { + observedChange = true; + statements = candidateStatements; + partitionNum = Math.max(partitionNum - 1, 2); + // reproducer.outputHook((SQLite3GlobalState) newGlobalState); + newGlobalState.getLogger().logReduced(newGlobalState.getState()); + break; + + } + } catch (Throwable ignoredException) { + + } + } catch (Exception e) { + e.printStackTrace(); + } + + currentReduceSteps++; + Instant currentInstant = Instant.now(); + + currentReduceTime = Duration.between(timeOfReductionBegins, currentInstant).getSeconds(); + if (!hasNotReachedLimit(currentReduceSteps, maxReduceSteps) + || !hasNotReachedLimit(currentReduceTime, maxReduceTime)) { + return statements; + } + start = start + subLength; + } + return statements; + } + + @SuppressWarnings("unused") + private void printQueries(List> statements) { + System.out.println("==============================="); + for (Query q : statements) { + System.out.println(q.getLogString()); + } + System.out.println("==============================="); + } +} diff --git a/src/sqlancer/arangodb/ArangoDBComparatorHelper.java b/src/sqlancer/arangodb/ArangoDBComparatorHelper.java deleted file mode 100644 index 2a00a312d..000000000 --- a/src/sqlancer/arangodb/ArangoDBComparatorHelper.java +++ /dev/null @@ -1,73 +0,0 @@ -package sqlancer.arangodb; - -import java.util.HashSet; -import java.util.List; -import java.util.Set; - -import com.arangodb.entity.BaseDocument; - -import sqlancer.IgnoreMeException; -import sqlancer.Main; -import sqlancer.arangodb.query.ArangoDBSelectQuery; -import sqlancer.common.query.ExpectedErrors; - -public final class ArangoDBComparatorHelper { - - private ArangoDBComparatorHelper() { - - } - - public static List getResultSetAsDocumentList(ArangoDBSelectQuery query, - ArangoDBProvider.ArangoDBGlobalState state) throws Exception { - ExpectedErrors errors = query.getExpectedErrors(); - List result; - try { - query.executeAndGet(state); - Main.nrSuccessfulActions.addAndGet(1); - result = query.getResultSet(); - return result; - } catch (Exception e) { - if (e instanceof IgnoreMeException) { - throw e; - } - Main.nrUnsuccessfulActions.addAndGet(1); - if (e.getMessage() == null) { - throw new AssertionError(query.getLogString(), e); - } - if (errors.errorIsExpected(e.getMessage())) { - throw new IgnoreMeException(); - } - throw new AssertionError(query.getLogString(), e); - } - - } - - public static void assumeResultSetsAreEqual(List resultSet, List secondResultSet, - ArangoDBSelectQuery originalQuery) { - if (resultSet.size() != secondResultSet.size()) { - String assertionMessage = String.format("The Size of the result sets mismatch (%d and %d)!\n%s", - resultSet.size(), secondResultSet.size(), originalQuery.getLogString()); - throw new AssertionError(assertionMessage); - } - Set firstHashSet = new HashSet<>(resultSet); - Set secondHashSet = new HashSet<>(secondResultSet); - - if (!firstHashSet.equals(secondHashSet)) { - Set firstResultSetMisses = new HashSet<>(firstHashSet); - firstResultSetMisses.removeAll(secondHashSet); - Set secondResultSetMisses = new HashSet<>(secondHashSet); - secondResultSetMisses.removeAll(firstHashSet); - StringBuilder firstMisses = new StringBuilder(); - for (BaseDocument document : firstResultSetMisses) { - firstMisses.append(document).append(" "); - } - StringBuilder secondMisses = new StringBuilder(); - for (BaseDocument document : secondResultSetMisses) { - secondMisses.append(document).append(" "); - } - String assertMessage = String.format("The Content of the result sets mismatch!\n %s \n %s\n %s", - firstMisses.toString(), secondMisses.toString(), originalQuery.getLogString()); - throw new AssertionError(assertMessage); - } - } -} diff --git a/src/sqlancer/arangodb/ArangoDBConnection.java b/src/sqlancer/arangodb/ArangoDBConnection.java deleted file mode 100644 index b3e5b85d3..000000000 --- a/src/sqlancer/arangodb/ArangoDBConnection.java +++ /dev/null @@ -1,31 +0,0 @@ -package sqlancer.arangodb; - -import com.arangodb.ArangoDB; -import com.arangodb.ArangoDatabase; - -import sqlancer.SQLancerDBConnection; - -public class ArangoDBConnection implements SQLancerDBConnection { - - private final ArangoDB client; - private final ArangoDatabase database; - - public ArangoDBConnection(ArangoDB client, ArangoDatabase database) { - this.client = client; - this.database = database; - } - - @Override - public String getDatabaseVersion() throws Exception { - return client.getVersion().getVersion(); - } - - @Override - public void close() throws Exception { - client.shutdown(); - } - - public ArangoDatabase getDatabase() { - return database; - } -} diff --git a/src/sqlancer/arangodb/ArangoDBLoggableFactory.java b/src/sqlancer/arangodb/ArangoDBLoggableFactory.java deleted file mode 100644 index 927d9f320..000000000 --- a/src/sqlancer/arangodb/ArangoDBLoggableFactory.java +++ /dev/null @@ -1,40 +0,0 @@ -package sqlancer.arangodb; - -import java.util.Arrays; - -import sqlancer.common.log.Loggable; -import sqlancer.common.log.LoggableFactory; -import sqlancer.common.log.LoggedString; -import sqlancer.common.query.Query; - -public class ArangoDBLoggableFactory extends LoggableFactory { - @Override - protected Loggable createLoggable(String input, String suffix) { - return new LoggedString(input + suffix); - } - - @Override - public Query getQueryForStateToReproduce(String queryString) { - throw new UnsupportedOperationException(); - } - - @Override - public Query commentOutQuery(Query query) { - throw new UnsupportedOperationException(); - } - - @Override - protected Loggable infoToLoggable(String time, String databaseName, String databaseVersion, long seedValue) { - StringBuilder sb = new StringBuilder(); - sb.append("// Time: ").append(time).append("\n"); - sb.append("// Database: ").append(databaseName).append("\n"); - sb.append("// Database version: ").append(databaseVersion).append("\n"); - sb.append("// seed value: ").append(seedValue).append("\n"); - return new LoggedString(sb.toString()); - } - - @Override - public Loggable convertStacktraceToLoggable(Throwable throwable) { - return new LoggedString(Arrays.toString(throwable.getStackTrace()) + "\n" + throwable.getMessage()); - } -} diff --git a/src/sqlancer/arangodb/ArangoDBOptions.java b/src/sqlancer/arangodb/ArangoDBOptions.java deleted file mode 100644 index b821fe8c3..000000000 --- a/src/sqlancer/arangodb/ArangoDBOptions.java +++ /dev/null @@ -1,50 +0,0 @@ -package sqlancer.arangodb; - -import static sqlancer.arangodb.ArangoDBOptions.ArangoDBOracleFactory.QUERY_PARTITIONING; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import com.beust.jcommander.Parameter; -import com.beust.jcommander.Parameters; - -import sqlancer.DBMSSpecificOptions; -import sqlancer.OracleFactory; -import sqlancer.arangodb.ArangoDBProvider.ArangoDBGlobalState; -import sqlancer.arangodb.test.ArangoDBQueryPartitioningWhereTester; -import sqlancer.common.oracle.CompositeTestOracle; -import sqlancer.common.oracle.TestOracle; - -@Parameters(commandDescription = "ArangoDB (experimental)") -public class ArangoDBOptions implements DBMSSpecificOptions { - - @Parameter(names = "--oracle") - public List oracles = Arrays.asList(QUERY_PARTITIONING); - - @Parameter(names = "--test-random-type-inserts", description = "Insert random types instead of schema types.") - public boolean testRandomTypeInserts; - - @Parameter(names = "--max-number-indexes", description = "The maximum number of indexes used.", arity = 1) - public int maxNumberIndexes = 15; - - @Parameter(names = "--with-optimizer-rule-tests", description = "Adds an additional query, where a random set" - + "of optimizer rules are disabled.", arity = 1) - public boolean withOptimizerRuleTests; - - @Override - public List getTestOracleFactory() { - return oracles; - } - - public enum ArangoDBOracleFactory implements OracleFactory { - QUERY_PARTITIONING { - @Override - public TestOracle create(ArangoDBGlobalState globalState) throws Exception { - List oracles = new ArrayList<>(); - oracles.add(new ArangoDBQueryPartitioningWhereTester(globalState)); - return new CompositeTestOracle(oracles, globalState); - } - } - } -} diff --git a/src/sqlancer/arangodb/ArangoDBProvider.java b/src/sqlancer/arangodb/ArangoDBProvider.java deleted file mode 100644 index 7bdc2fa01..000000000 --- a/src/sqlancer/arangodb/ArangoDBProvider.java +++ /dev/null @@ -1,137 +0,0 @@ -package sqlancer.arangodb; - -import java.util.ArrayList; -import java.util.List; - -import com.arangodb.ArangoDB; -import com.arangodb.ArangoDatabase; -import com.google.auto.service.AutoService; - -import sqlancer.AbstractAction; -import sqlancer.DatabaseProvider; -import sqlancer.ExecutionTimer; -import sqlancer.GlobalState; -import sqlancer.IgnoreMeException; -import sqlancer.ProviderAdapter; -import sqlancer.Randomly; -import sqlancer.StatementExecutor; -import sqlancer.arangodb.gen.ArangoDBCreateIndexGenerator; -import sqlancer.arangodb.gen.ArangoDBInsertGenerator; -import sqlancer.arangodb.gen.ArangoDBTableGenerator; -import sqlancer.common.log.LoggableFactory; -import sqlancer.common.query.Query; - -@AutoService(DatabaseProvider.class) -public class ArangoDBProvider - extends ProviderAdapter { - - public ArangoDBProvider() { - super(ArangoDBGlobalState.class, ArangoDBOptions.class); - } - - enum Action implements AbstractAction { - INSERT(ArangoDBInsertGenerator::getQuery), CREATE_INDEX(ArangoDBCreateIndexGenerator::getQuery); - - private final ArangoDBQueryProvider queryProvider; - - Action(ArangoDBQueryProvider queryProvider) { - this.queryProvider = queryProvider; - } - - @Override - public Query getQuery(ArangoDBGlobalState globalState) throws Exception { - return queryProvider.getQuery(globalState); - } - } - - private static int mapActions(ArangoDBGlobalState globalState, Action a) { - Randomly r = globalState.getRandomly(); - switch (a) { - case INSERT: - return r.getInteger(0, globalState.getOptions().getMaxNumberInserts()); - case CREATE_INDEX: - return r.getInteger(0, globalState.getDbmsSpecificOptions().maxNumberIndexes); - default: - throw new AssertionError(a); - } - } - - public static class ArangoDBGlobalState extends GlobalState { - - private final List schemaTables = new ArrayList<>(); - - public void addTable(ArangoDBSchema.ArangoDBTable table) { - schemaTables.add(table); - } - - @Override - protected void executeEpilogue(Query q, boolean success, ExecutionTimer timer) throws Exception { - boolean logExecutionTime = getOptions().logExecutionTime(); - if (success && getOptions().printSucceedingStatements()) { - System.out.println(q.getLogString()); - } - if (logExecutionTime) { - getLogger().writeCurrent("//" + timer.end().asString()); - } - if (q.couldAffectSchema()) { - updateSchema(); - } - } - - @Override - protected ArangoDBSchema readSchema() throws Exception { - return new ArangoDBSchema(schemaTables); - } - } - - @Override - protected void checkViewsAreValid(ArangoDBGlobalState globalState) { - - } - - @Override - public void generateDatabase(ArangoDBGlobalState globalState) throws Exception { - for (int i = 0; i < Randomly.fromOptions(4, 5, 6); i++) { - boolean success; - do { - ArangoDBQueryAdapter queryAdapter = new ArangoDBTableGenerator().getQuery(globalState); - success = globalState.executeStatement(queryAdapter); - } while (!success); - } - StatementExecutor se = new StatementExecutor<>(globalState, Action.values(), - ArangoDBProvider::mapActions, (q) -> { - if (globalState.getSchema().getDatabaseTables().isEmpty()) { - throw new IgnoreMeException(); - } - }); - se.executeStatements(); - } - - @Override - public ArangoDBConnection createDatabase(ArangoDBGlobalState globalState) throws Exception { - ArangoDB arangoDB = new ArangoDB.Builder().user(globalState.getOptions().getUserName()) - .password(globalState.getOptions().getPassword()).build(); - ArangoDatabase database = arangoDB.db(globalState.getDatabaseName()); - try { - database.drop(); - // When the database does not exist, an ArangoDB exception is thrown. Since we are not sure - // if this is the first time the database is used, the simplest is dropping it and ignoring - // the exception. - } catch (Exception ignored) { - - } - arangoDB.createDatabase(globalState.getDatabaseName()); - database = arangoDB.db(globalState.getDatabaseName()); - return new ArangoDBConnection(arangoDB, database); - } - - @Override - public String getDBMSName() { - return "arangodb"; - } - - @Override - public LoggableFactory getLoggableFactory() { - return new ArangoDBLoggableFactory(); - } -} diff --git a/src/sqlancer/arangodb/ArangoDBQueryAdapter.java b/src/sqlancer/arangodb/ArangoDBQueryAdapter.java deleted file mode 100644 index 34cdb3709..000000000 --- a/src/sqlancer/arangodb/ArangoDBQueryAdapter.java +++ /dev/null @@ -1,16 +0,0 @@ -package sqlancer.arangodb; - -import sqlancer.common.query.Query; - -public abstract class ArangoDBQueryAdapter extends Query { - @Override - public String getQueryString() { - // Should not be called as it is used only in SQL dependent classes - throw new UnsupportedOperationException(); - } - - @Override - public String getUnterminatedQueryString() { - throw new UnsupportedOperationException(); - } -} diff --git a/src/sqlancer/arangodb/ArangoDBQueryProvider.java b/src/sqlancer/arangodb/ArangoDBQueryProvider.java deleted file mode 100644 index 94a4ffda3..000000000 --- a/src/sqlancer/arangodb/ArangoDBQueryProvider.java +++ /dev/null @@ -1,6 +0,0 @@ -package sqlancer.arangodb; - -@FunctionalInterface -public interface ArangoDBQueryProvider { - ArangoDBQueryAdapter getQuery(S globalState) throws Exception; -} diff --git a/src/sqlancer/arangodb/ArangoDBSchema.java b/src/sqlancer/arangodb/ArangoDBSchema.java deleted file mode 100644 index 35e251b8b..000000000 --- a/src/sqlancer/arangodb/ArangoDBSchema.java +++ /dev/null @@ -1,70 +0,0 @@ -package sqlancer.arangodb; - -import java.util.Collections; -import java.util.List; - -import sqlancer.Randomly; -import sqlancer.common.schema.AbstractSchema; -import sqlancer.common.schema.AbstractTable; -import sqlancer.common.schema.AbstractTableColumn; -import sqlancer.common.schema.AbstractTables; -import sqlancer.common.schema.TableIndex; - -public class ArangoDBSchema extends AbstractSchema { - - public enum ArangoDBDataType { - INTEGER, DOUBLE, STRING, BOOLEAN; - - public static ArangoDBDataType getRandom() { - return Randomly.fromOptions(values()); - } - } - - public static class ArangoDBColumn extends AbstractTableColumn { - - private final boolean isId; - private final boolean isNullable; - - public ArangoDBColumn(String name, ArangoDBDataType type, boolean isId, boolean isNullable) { - super(name, null, type); - this.isId = isId; - this.isNullable = isNullable; - } - - public boolean isId() { - return isId; - } - - public boolean isNullable() { - return isNullable; - } - } - - public ArangoDBSchema(List databaseTables) { - super(databaseTables); - } - - public static class ArangoDBTables extends AbstractTables { - - public ArangoDBTables(List tables) { - super(tables); - } - } - - public static class ArangoDBTable - extends AbstractTable { - - public ArangoDBTable(String name, List columns, boolean isView) { - super(name, columns, Collections.emptyList(), isView); - } - - @Override - public long getNrRows(ArangoDBProvider.ArangoDBGlobalState globalState) { - throw new UnsupportedOperationException(); - } - } - - public ArangoDBTables getRandomTableNonEmptyTables() { - return new ArangoDBTables(Randomly.nonEmptySubset(getDatabaseTables())); - } -} diff --git a/src/sqlancer/arangodb/ast/ArangoDBConstant.java b/src/sqlancer/arangodb/ast/ArangoDBConstant.java deleted file mode 100644 index 351dbd822..000000000 --- a/src/sqlancer/arangodb/ast/ArangoDBConstant.java +++ /dev/null @@ -1,108 +0,0 @@ -package sqlancer.arangodb.ast; - -import com.arangodb.entity.BaseDocument; - -import sqlancer.common.ast.newast.Node; - -public abstract class ArangoDBConstant implements Node { - private ArangoDBConstant() { - - } - - public abstract void setValueInDocument(BaseDocument document, String key); - - public abstract Object getValue(); - - public static class ArangoDBIntegerConstant extends ArangoDBConstant { - - private final int value; - - public ArangoDBIntegerConstant(int value) { - this.value = value; - } - - @Override - public void setValueInDocument(BaseDocument document, String key) { - document.addAttribute(key, value); - } - - @Override - public Object getValue() { - return value; - } - } - - public static Node createIntegerConstant(int value) { - return new ArangoDBIntegerConstant(value); - } - - public static class ArangoDBStringConstant extends ArangoDBConstant { - private final String value; - - public ArangoDBStringConstant(String value) { - this.value = value; - } - - @Override - public void setValueInDocument(BaseDocument document, String key) { - document.addAttribute(key, value); - } - - @Override - public Object getValue() { - return "'" + value.replace("\\", "\\\\").replace("'", "\\'") + "'"; - } - } - - public static Node createStringConstant(String value) { - return new ArangoDBStringConstant(value); - } - - public static class ArangoDBBooleanConstant extends ArangoDBConstant { - private final boolean value; - - public ArangoDBBooleanConstant(boolean value) { - this.value = value; - } - - @Override - public void setValueInDocument(BaseDocument document, String key) { - document.addAttribute(key, value); - } - - @Override - public Object getValue() { - return value; - } - } - - public static Node createBooleanConstant(boolean value) { - return new ArangoDBBooleanConstant(value); - } - - public static class ArangoDBDoubleConstant extends ArangoDBConstant { - private final double value; - - public ArangoDBDoubleConstant(double value) { - if (Double.isInfinite(value) || Double.isNaN(value)) { - this.value = 0.0; - } else { - this.value = value; - } - } - - @Override - public void setValueInDocument(BaseDocument document, String key) { - document.addAttribute(key, value); - } - - @Override - public Object getValue() { - return value; - } - } - - public static Node createDoubleConstant(double value) { - return new ArangoDBDoubleConstant(value); - } -} diff --git a/src/sqlancer/arangodb/ast/ArangoDBExpression.java b/src/sqlancer/arangodb/ast/ArangoDBExpression.java deleted file mode 100644 index facbbfe9e..000000000 --- a/src/sqlancer/arangodb/ast/ArangoDBExpression.java +++ /dev/null @@ -1,4 +0,0 @@ -package sqlancer.arangodb.ast; - -public interface ArangoDBExpression { -} diff --git a/src/sqlancer/arangodb/ast/ArangoDBSelect.java b/src/sqlancer/arangodb/ast/ArangoDBSelect.java deleted file mode 100644 index 9fb91d553..000000000 --- a/src/sqlancer/arangodb/ast/ArangoDBSelect.java +++ /dev/null @@ -1,79 +0,0 @@ -package sqlancer.arangodb.ast; - -import java.util.List; - -import sqlancer.arangodb.ArangoDBSchema; -import sqlancer.common.ast.newast.Node; - -public class ArangoDBSelect implements Node { - private List fromColumns; - private List projectionColumns; - private boolean hasFilter; - private Node filterClause; - private boolean hasComputed; - private List> computedClause; - - public List getFromColumns() { - if (fromColumns == null || fromColumns.isEmpty()) { - throw new IllegalStateException(); - } - return fromColumns; - } - - public void setFromColumns(List fromColumns) { - if (fromColumns == null || fromColumns.isEmpty()) { - throw new IllegalStateException(); - } - this.fromColumns = fromColumns; - } - - public List getProjectionColumns() { - if (projectionColumns == null) { - throw new IllegalStateException(); - } - return projectionColumns; - } - - public void setProjectionColumns(List projectionColumns) { - if (projectionColumns == null) { - throw new IllegalStateException(); - } - this.projectionColumns = projectionColumns; - } - - public void setFilterClause(Node filterClause) { - if (filterClause == null) { - hasFilter = false; - this.filterClause = null; - return; - } - hasFilter = true; - this.filterClause = filterClause; - } - - public Node getFilterClause() { - return filterClause; - } - - public boolean hasFilter() { - return hasFilter; - } - - public void setComputedClause(List> computedColumns) { - if (computedColumns == null || computedColumns.isEmpty()) { - hasComputed = false; - this.computedClause = null; - return; - } - hasComputed = true; - this.computedClause = computedColumns; - } - - public List> getComputedClause() { - return computedClause; - } - - public boolean hasComputed() { - return hasComputed; - } -} diff --git a/src/sqlancer/arangodb/ast/ArangoDBUnsupportedPredicate.java b/src/sqlancer/arangodb/ast/ArangoDBUnsupportedPredicate.java deleted file mode 100644 index eabd25578..000000000 --- a/src/sqlancer/arangodb/ast/ArangoDBUnsupportedPredicate.java +++ /dev/null @@ -1,6 +0,0 @@ -package sqlancer.arangodb.ast; - -import sqlancer.common.ast.newast.Node; - -public class ArangoDBUnsupportedPredicate implements Node { -} diff --git a/src/sqlancer/arangodb/gen/ArangoDBComputedExpressionGenerator.java b/src/sqlancer/arangodb/gen/ArangoDBComputedExpressionGenerator.java deleted file mode 100644 index 01e2e557a..000000000 --- a/src/sqlancer/arangodb/gen/ArangoDBComputedExpressionGenerator.java +++ /dev/null @@ -1,85 +0,0 @@ -package sqlancer.arangodb.gen; - -import sqlancer.Randomly; -import sqlancer.arangodb.ArangoDBProvider; -import sqlancer.arangodb.ArangoDBSchema; -import sqlancer.arangodb.ast.ArangoDBConstant; -import sqlancer.arangodb.ast.ArangoDBExpression; -import sqlancer.common.ast.newast.ColumnReferenceNode; -import sqlancer.common.ast.newast.NewFunctionNode; -import sqlancer.common.ast.newast.Node; -import sqlancer.common.gen.UntypedExpressionGenerator; - -public class ArangoDBComputedExpressionGenerator - extends UntypedExpressionGenerator, ArangoDBSchema.ArangoDBColumn> { - private final ArangoDBProvider.ArangoDBGlobalState globalState; - - public ArangoDBComputedExpressionGenerator(ArangoDBProvider.ArangoDBGlobalState globalState) { - this.globalState = globalState; - } - - @Override - public Node generateConstant() { - ArangoDBSchema.ArangoDBDataType dataType = ArangoDBSchema.ArangoDBDataType.getRandom(); - switch (dataType) { - case INTEGER: - return ArangoDBConstant.createIntegerConstant((int) globalState.getRandomly().getInteger()); - case BOOLEAN: - return ArangoDBConstant.createBooleanConstant(Randomly.getBoolean()); - case DOUBLE: - return ArangoDBConstant.createDoubleConstant(globalState.getRandomly().getDouble()); - case STRING: - return ArangoDBConstant.createStringConstant(globalState.getRandomly().getString()); - default: - throw new AssertionError(dataType); - } - } - - public enum ComputedFunction { - ADD(2, "+"), MINUS(2, "-"), MULTIPLY(2, "*"), DIVISION(2, "/"), MODULUS(2, "%"); - - private final int nrArgs; - private final String operatorName; - - ComputedFunction(int nrArgs, String operatorName) { - this.nrArgs = nrArgs; - this.operatorName = operatorName; - } - - public static ComputedFunction getRandom() { - return Randomly.fromOptions(values()); - } - - public int getNrArgs() { - return nrArgs; - } - - public String getOperatorName() { - return operatorName; - } - } - - @Override - protected Node generateExpression(int depth) { - if (depth >= globalState.getOptions().getMaxExpressionDepth() || Randomly.getBoolean()) { - return generateLeafNode(); - } - ComputedFunction function = ComputedFunction.getRandom(); - return new NewFunctionNode<>(generateExpressions(function.getNrArgs(), depth + 1), function); - } - - @Override - protected Node generateColumn() { - return new ColumnReferenceNode<>(Randomly.fromList(columns)); - } - - @Override - public Node negatePredicate(Node predicate) { - throw new UnsupportedOperationException(); - } - - @Override - public Node isNull(Node expr) { - throw new UnsupportedOperationException(); - } -} diff --git a/src/sqlancer/arangodb/gen/ArangoDBCreateIndexGenerator.java b/src/sqlancer/arangodb/gen/ArangoDBCreateIndexGenerator.java deleted file mode 100644 index 6a1b872da..000000000 --- a/src/sqlancer/arangodb/gen/ArangoDBCreateIndexGenerator.java +++ /dev/null @@ -1,18 +0,0 @@ -package sqlancer.arangodb.gen; - -import sqlancer.arangodb.ArangoDBProvider; -import sqlancer.arangodb.ArangoDBQueryAdapter; -import sqlancer.arangodb.ArangoDBSchema; -import sqlancer.arangodb.query.ArangoDBCreateIndexQuery; - -public final class ArangoDBCreateIndexGenerator { - private ArangoDBCreateIndexGenerator() { - - } - - public static ArangoDBQueryAdapter getQuery(ArangoDBProvider.ArangoDBGlobalState globalState) { - ArangoDBSchema.ArangoDBTable randomTable = globalState.getSchema().getRandomTable(); - ArangoDBSchema.ArangoDBColumn column = randomTable.getRandomColumn(); - return new ArangoDBCreateIndexQuery(column); - } -} diff --git a/src/sqlancer/arangodb/gen/ArangoDBFilterExpressionGenerator.java b/src/sqlancer/arangodb/gen/ArangoDBFilterExpressionGenerator.java deleted file mode 100644 index 1a2fc4b5e..000000000 --- a/src/sqlancer/arangodb/gen/ArangoDBFilterExpressionGenerator.java +++ /dev/null @@ -1,153 +0,0 @@ -package sqlancer.arangodb.gen; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import sqlancer.Randomly; -import sqlancer.arangodb.ArangoDBProvider; -import sqlancer.arangodb.ArangoDBSchema; -import sqlancer.arangodb.ast.ArangoDBConstant; -import sqlancer.arangodb.ast.ArangoDBExpression; -import sqlancer.arangodb.ast.ArangoDBUnsupportedPredicate; -import sqlancer.common.ast.BinaryOperatorNode; -import sqlancer.common.ast.newast.ColumnReferenceNode; -import sqlancer.common.ast.newast.NewBinaryOperatorNode; -import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode; -import sqlancer.common.ast.newast.Node; -import sqlancer.common.gen.UntypedExpressionGenerator; - -public class ArangoDBFilterExpressionGenerator - extends UntypedExpressionGenerator, ArangoDBSchema.ArangoDBColumn> { - - private final ArangoDBProvider.ArangoDBGlobalState globalState; - private int numberOfComputedVariables; - - private enum Expression { - BINARY_LOGICAL, UNARY_PREFIX, BINARY_COMPARISON - } - - public ArangoDBFilterExpressionGenerator(ArangoDBProvider.ArangoDBGlobalState globalState) { - this.globalState = globalState; - } - - public void setNumberOfComputedVariables(int numberOfComputedVariables) { - this.numberOfComputedVariables = numberOfComputedVariables; - } - - @Override - public Node generateConstant() { - ArangoDBSchema.ArangoDBDataType dataType = ArangoDBSchema.ArangoDBDataType.getRandom(); - switch (dataType) { - case INTEGER: - return ArangoDBConstant.createIntegerConstant((int) globalState.getRandomly().getInteger()); - case BOOLEAN: - return ArangoDBConstant.createBooleanConstant(Randomly.getBoolean()); - case DOUBLE: - return ArangoDBConstant.createDoubleConstant(globalState.getRandomly().getDouble()); - case STRING: - return ArangoDBConstant.createStringConstant(globalState.getRandomly().getString()); - default: - throw new AssertionError(dataType); - } - } - - @Override - protected Node generateExpression(int depth) { - if (depth >= globalState.getOptions().getMaxExpressionDepth() || Randomly.getBoolean()) { - return generateLeafNode(); - } - List possibleOptions = new ArrayList<>(Arrays.asList(Expression.values())); - Expression expression = Randomly.fromList(possibleOptions); - switch (expression) { - case BINARY_COMPARISON: - BinaryOperatorNode.Operator op = ArangoDBBinaryComparisonOperator.getRandom(); - return new NewBinaryOperatorNode<>(generateExpression(depth + 1), generateExpression(depth + 1), op); - case UNARY_PREFIX: - return new NewUnaryPrefixOperatorNode<>(generateExpression(depth + 1), - ArangoDBUnaryPrefixOperator.getRandom()); - case BINARY_LOGICAL: - op = ArangoDBBinaryLogicalOperator.getRandom(); - return new NewBinaryOperatorNode<>(generateExpression(depth + 1), generateExpression(depth + 1), op); - default: - throw new AssertionError(expression); - } - } - - @Override - protected Node generateColumn() { - ArangoDBSchema.ArangoDBTable dummy = new ArangoDBSchema.ArangoDBTable("", new ArrayList<>(), false); - if (Randomly.getBoolean() || numberOfComputedVariables == 0) { - ArangoDBSchema.ArangoDBColumn column = Randomly.fromList(columns); - return new ColumnReferenceNode<>(column); - } else { - int maxNumber = globalState.getRandomly().getInteger(0, numberOfComputedVariables); - ArangoDBSchema.ArangoDBColumn column = new ArangoDBSchema.ArangoDBColumn("c" + maxNumber, - ArangoDBSchema.ArangoDBDataType.INTEGER, false, false); - column.setTable(dummy); - return new ColumnReferenceNode<>(column); - } - } - - @Override - public Node negatePredicate(Node predicate) { - return new NewUnaryPrefixOperatorNode<>(predicate, ArangoDBUnaryPrefixOperator.NOT); - } - - @Override - public Node isNull(Node expr) { - return new ArangoDBUnsupportedPredicate<>(); - } - - public enum ArangoDBBinaryComparisonOperator implements BinaryOperatorNode.Operator { - EQUALS("=="), NOT_EQUALS("!="), LESS_THAN("<"), LESS_OR_EQUAL("<="), GREATER_THAN(">"), GREATER_OR_EQUAL(">="); - - private final String representation; - - ArangoDBBinaryComparisonOperator(String representation) { - this.representation = representation; - } - - @Override - public String getTextRepresentation() { - return representation; - } - - public static ArangoDBBinaryComparisonOperator getRandom() { - return Randomly.fromOptions(values()); - } - } - - public enum ArangoDBUnaryPrefixOperator implements BinaryOperatorNode.Operator { - NOT("!"); - - private final String representation; - - ArangoDBUnaryPrefixOperator(String representation) { - this.representation = representation; - } - - @Override - public String getTextRepresentation() { - return representation; - } - - public static ArangoDBUnaryPrefixOperator getRandom() { - return Randomly.fromOptions(values()); - } - } - - public enum ArangoDBBinaryLogicalOperator implements BinaryOperatorNode.Operator { - AND, OR; - - @Override - public String getTextRepresentation() { - return toString(); - } - - public static BinaryOperatorNode.Operator getRandom() { - return Randomly.fromOptions(values()); - } - } - -} diff --git a/src/sqlancer/arangodb/gen/ArangoDBInsertGenerator.java b/src/sqlancer/arangodb/gen/ArangoDBInsertGenerator.java deleted file mode 100644 index 3cfceeed4..000000000 --- a/src/sqlancer/arangodb/gen/ArangoDBInsertGenerator.java +++ /dev/null @@ -1,39 +0,0 @@ -package sqlancer.arangodb.gen; - -import com.arangodb.entity.BaseDocument; - -import sqlancer.arangodb.ArangoDBProvider; -import sqlancer.arangodb.ArangoDBQueryAdapter; -import sqlancer.arangodb.ArangoDBSchema; -import sqlancer.arangodb.query.ArangoDBConstantGenerator; -import sqlancer.arangodb.query.ArangoDBInsertQuery; - -public final class ArangoDBInsertGenerator { - - private final ArangoDBProvider.ArangoDBGlobalState globalState; - - private ArangoDBInsertGenerator(ArangoDBProvider.ArangoDBGlobalState globalState) { - this.globalState = globalState; - } - - public static ArangoDBQueryAdapter getQuery(ArangoDBProvider.ArangoDBGlobalState globalState) { - return new ArangoDBInsertGenerator(globalState).generate(); - } - - private ArangoDBQueryAdapter generate() { - BaseDocument result = new BaseDocument(); - ArangoDBSchema.ArangoDBTable table = globalState.getSchema().getRandomTable(); - ArangoDBConstantGenerator constantGenerator = new ArangoDBConstantGenerator(globalState); - - for (int i = 0; i < table.getColumns().size(); i++) { - if (!globalState.getDbmsSpecificOptions().testRandomTypeInserts) { - constantGenerator.addRandomConstantWithType(result, table.getColumns().get(i).getName(), - table.getColumns().get(i).getType()); - } else { - constantGenerator.addRandomConstant(result, table.getColumns().get(i).getName()); - } - } - - return new ArangoDBInsertQuery(table, result); - } -} diff --git a/src/sqlancer/arangodb/gen/ArangoDBTableGenerator.java b/src/sqlancer/arangodb/gen/ArangoDBTableGenerator.java deleted file mode 100644 index 1236c3ce4..000000000 --- a/src/sqlancer/arangodb/gen/ArangoDBTableGenerator.java +++ /dev/null @@ -1,44 +0,0 @@ -package sqlancer.arangodb.gen; - -import java.util.ArrayList; -import java.util.List; - -import sqlancer.Randomly; -import sqlancer.arangodb.ArangoDBProvider; -import sqlancer.arangodb.ArangoDBQueryAdapter; -import sqlancer.arangodb.ArangoDBSchema; -import sqlancer.arangodb.query.ArangoDBCreateTableQuery; - -public class ArangoDBTableGenerator { - - private ArangoDBSchema.ArangoDBTable table; - private final List columnsToBeAdded = new ArrayList<>(); - - public ArangoDBQueryAdapter getQuery(ArangoDBProvider.ArangoDBGlobalState globalState) { - String tableName = globalState.getSchema().getFreeTableName(); - ArangoDBCreateTableQuery createTableQuery = new ArangoDBCreateTableQuery(tableName); - table = new ArangoDBSchema.ArangoDBTable(tableName, columnsToBeAdded, false); - for (int i = 0; i < Randomly.smallNumber() + 1; i++) { - String columnName = String.format("c%d", i); - createColumn(columnName); - } - globalState.addTable(table); - return createTableQuery; - } - - private ArangoDBSchema.ArangoDBDataType createColumn(String columnName) { - ArangoDBSchema.ArangoDBDataType dataType = ArangoDBSchema.ArangoDBDataType.getRandom(); - ArangoDBSchema.ArangoDBColumn newColumn = new ArangoDBSchema.ArangoDBColumn(columnName, dataType, false, false); - newColumn.setTable(table); - columnsToBeAdded.add(newColumn); - return dataType; - } - - public String getTableName() { - return table.getName(); - } - - public ArangoDBSchema.ArangoDBTable getGeneratedTable() { - return table; - } -} diff --git a/src/sqlancer/arangodb/query/ArangoDBConstantGenerator.java b/src/sqlancer/arangodb/query/ArangoDBConstantGenerator.java deleted file mode 100644 index 406e8adca..000000000 --- a/src/sqlancer/arangodb/query/ArangoDBConstantGenerator.java +++ /dev/null @@ -1,46 +0,0 @@ -package sqlancer.arangodb.query; - -import com.arangodb.entity.BaseDocument; - -import sqlancer.Randomly; -import sqlancer.arangodb.ArangoDBProvider; -import sqlancer.arangodb.ArangoDBSchema; -import sqlancer.arangodb.ast.ArangoDBConstant; - -public class ArangoDBConstantGenerator { - private final ArangoDBProvider.ArangoDBGlobalState globalState; - - public ArangoDBConstantGenerator(ArangoDBProvider.ArangoDBGlobalState globalState) { - this.globalState = globalState; - } - - public void addRandomConstant(BaseDocument document, String key) { - ArangoDBSchema.ArangoDBDataType type = ArangoDBSchema.ArangoDBDataType.getRandom(); - addRandomConstantWithType(document, key, type); - } - - public void addRandomConstantWithType(BaseDocument document, String key, ArangoDBSchema.ArangoDBDataType dataType) { - ArangoDBConstant constant; - switch (dataType) { - case STRING: - constant = new ArangoDBConstant.ArangoDBStringConstant(globalState.getRandomly().getString()); - constant.setValueInDocument(document, key); - return; - case DOUBLE: - constant = new ArangoDBConstant.ArangoDBDoubleConstant(globalState.getRandomly().getDouble()); - constant.setValueInDocument(document, key); - return; - case BOOLEAN: - constant = new ArangoDBConstant.ArangoDBBooleanConstant(Randomly.getBoolean()); - constant.setValueInDocument(document, key); - return; - case INTEGER: - constant = new ArangoDBConstant.ArangoDBIntegerConstant((int) globalState.getRandomly().getInteger()); - constant.setValueInDocument(document, key); - return; - default: - throw new AssertionError(dataType); - } - - } -} diff --git a/src/sqlancer/arangodb/query/ArangoDBCreateIndexQuery.java b/src/sqlancer/arangodb/query/ArangoDBCreateIndexQuery.java deleted file mode 100644 index 6c2cc1b75..000000000 --- a/src/sqlancer/arangodb/query/ArangoDBCreateIndexQuery.java +++ /dev/null @@ -1,54 +0,0 @@ -package sqlancer.arangodb.query; - -import java.util.Collections; - -import com.arangodb.ArangoCollection; - -import sqlancer.GlobalState; -import sqlancer.Main; -import sqlancer.arangodb.ArangoDBConnection; -import sqlancer.arangodb.ArangoDBQueryAdapter; -import sqlancer.arangodb.ArangoDBSchema; -import sqlancer.common.query.ExpectedErrors; - -public class ArangoDBCreateIndexQuery extends ArangoDBQueryAdapter { - - private final ArangoDBSchema.ArangoDBColumn column; - - public ArangoDBCreateIndexQuery(ArangoDBSchema.ArangoDBColumn column) { - this.column = column; - } - - @Override - public boolean couldAffectSchema() { - return false; - } - - @Override - public > boolean execute(G globalState, String... fills) - throws Exception { - try { - ArangoCollection collection = globalState.getConnection().getDatabase() - .collection(column.getTable().getName()); - collection.ensureHashIndex(Collections.singletonList(column.getName()), null); - Main.nrSuccessfulActions.addAndGet(1); - return true; - } catch (Exception e) { - Main.nrUnsuccessfulActions.addAndGet(1); - throw e; - } - } - - @Override - public ExpectedErrors getExpectedErrors() { - return new ExpectedErrors(); - } - - @Override - public String getLogString() { - StringBuilder stringBuilder = new StringBuilder(); - stringBuilder.append("db.").append(column.getTable().getName()) - .append(".ensureIndex({type: \"hash\", fields: [ \"").append(column.getName()).append("\" ]});"); - return stringBuilder.toString(); - } -} diff --git a/src/sqlancer/arangodb/query/ArangoDBCreateTableQuery.java b/src/sqlancer/arangodb/query/ArangoDBCreateTableQuery.java deleted file mode 100644 index 00b3276d0..000000000 --- a/src/sqlancer/arangodb/query/ArangoDBCreateTableQuery.java +++ /dev/null @@ -1,44 +0,0 @@ -package sqlancer.arangodb.query; - -import sqlancer.GlobalState; -import sqlancer.Main; -import sqlancer.arangodb.ArangoDBConnection; -import sqlancer.arangodb.ArangoDBQueryAdapter; -import sqlancer.common.query.ExpectedErrors; - -public class ArangoDBCreateTableQuery extends ArangoDBQueryAdapter { - - private final String tableName; - - public ArangoDBCreateTableQuery(String tableName) { - this.tableName = tableName; - } - - @Override - public boolean couldAffectSchema() { - return true; - } - - @Override - public > boolean execute(G globalState, String... fills) - throws Exception { - try { - globalState.getConnection().getDatabase().createCollection(tableName); - Main.nrSuccessfulActions.addAndGet(1); - return true; - } catch (Exception e) { - Main.nrUnsuccessfulActions.addAndGet(1); - throw e; - } - } - - @Override - public ExpectedErrors getExpectedErrors() { - return new ExpectedErrors(); - } - - @Override - public String getLogString() { - return "db._create(\"" + tableName + "\")"; - } -} diff --git a/src/sqlancer/arangodb/query/ArangoDBInsertQuery.java b/src/sqlancer/arangodb/query/ArangoDBInsertQuery.java deleted file mode 100644 index 9a3612062..000000000 --- a/src/sqlancer/arangodb/query/ArangoDBInsertQuery.java +++ /dev/null @@ -1,66 +0,0 @@ -package sqlancer.arangodb.query; - -import java.util.Map; - -import com.arangodb.entity.BaseDocument; - -import sqlancer.GlobalState; -import sqlancer.Main; -import sqlancer.arangodb.ArangoDBConnection; -import sqlancer.arangodb.ArangoDBQueryAdapter; -import sqlancer.arangodb.ArangoDBSchema; -import sqlancer.common.query.ExpectedErrors; - -public class ArangoDBInsertQuery extends ArangoDBQueryAdapter { - - private final ArangoDBSchema.ArangoDBTable table; - private final BaseDocument documentToBeInserted; - - public ArangoDBInsertQuery(ArangoDBSchema.ArangoDBTable table, BaseDocument documentToBeInserted) { - this.table = table; - this.documentToBeInserted = documentToBeInserted; - } - - @Override - public boolean couldAffectSchema() { - return true; - } - - @Override - public > boolean execute(G globalState, String... fills) - throws Exception { - try { - globalState.getConnection().getDatabase().collection(table.getName()).insertDocument(documentToBeInserted); - Main.nrSuccessfulActions.addAndGet(1); - return true; - } catch (Exception e) { - Main.nrUnsuccessfulActions.addAndGet(1); - throw e; - } - } - - @Override - public ExpectedErrors getExpectedErrors() { - return new ExpectedErrors(); - } - - @Override - public String getLogString() { - StringBuilder stringBuilder = new StringBuilder(); - stringBuilder.append("db._query(\"INSERT { "); - String filler = ""; - for (Map.Entry stringObjectEntry : documentToBeInserted.getProperties().entrySet()) { - stringBuilder.append(filler); - filler = ", "; - stringBuilder.append(stringObjectEntry.getKey()).append(": "); - Object value = stringObjectEntry.getValue(); - if (value instanceof String) { - stringBuilder.append("'").append(value).append("'"); - } else { - stringBuilder.append(value); - } - } - stringBuilder.append("} IN ").append(table.getName()).append("\")"); - return stringBuilder.toString(); - } -} diff --git a/src/sqlancer/arangodb/query/ArangoDBOptimizerRules.java b/src/sqlancer/arangodb/query/ArangoDBOptimizerRules.java deleted file mode 100644 index 835849b92..000000000 --- a/src/sqlancer/arangodb/query/ArangoDBOptimizerRules.java +++ /dev/null @@ -1,57 +0,0 @@ -package sqlancer.arangodb.query; - -import java.util.ArrayList; -import java.util.List; - -import sqlancer.Randomly; - -public class ArangoDBOptimizerRules { - - private final List allRules = new ArrayList<>(); - - public ArangoDBOptimizerRules() { - // SRC: - // https://www.arangodb.com/docs/stable/aql/execution-and-performance-optimizer.html#list-of-optimizer-rules - // Filtered out irrelevant ones - allRules.add("-fuse-filters"); - // allRules.add("-geo-index-optimizer"); - // allRules.add("-handle-arangosearch-views"); - // allRules.add("-inline-subqueries"); - allRules.add("-interchange-adjacent-enumerations"); - allRules.add("-late-document-materialization"); - // allRules.add("-late-document-materialization-arangosearch"); - allRules.add("-move-calculations-down"); - allRules.add("-move-calculations-up"); - allRules.add("-move-filters-into-enumerate"); - allRules.add("-move-filters-up"); - // allRules.add("-optimize-count"); - // allRules.add("-optimize-subqueries"); - // allRules.add("-optimize-traversals"); - // allRules.add("-patch-update-statements"); - allRules.add("-propagate-constant-attributes"); - allRules.add("-reduce-extraction-to-projection"); - // allRules.add("-remove-collect-variables"); - // allRules.add("-remove-data-modification-out-variables"); - allRules.add("-remove-filter-covered-by-index"); - // allRules.add("-remove-filter-covered-by-traversal"); - allRules.add("-remove-redundant-calculations"); - allRules.add("-remove-redundant-or"); - // allRules.add("-remove-redundant-path-var"); - // allRules.add("-remove-redundant-sorts"); - // allRules.add("-remove-sort-rand"); - allRules.add("-remove-unnecessary-calculations"); - allRules.add("-remove-unnecessary-filters"); - // allRules.add("-replace-function-with-index"); - allRules.add("-replace-or-with-in"); - allRules.add("-simplify-conditions"); - // allRules.add("-sort-in-values"); - // allRules.add("-sort-limit"); - // allRules.add("-splice-subqueries"); - // allRules.add("-use-index-for-sort"); - allRules.add("-use-indexes"); - } - - public List getRandomRules() { - return Randomly.subset(allRules); - } -} diff --git a/src/sqlancer/arangodb/query/ArangoDBSelectQuery.java b/src/sqlancer/arangodb/query/ArangoDBSelectQuery.java deleted file mode 100644 index 4725e4178..000000000 --- a/src/sqlancer/arangodb/query/ArangoDBSelectQuery.java +++ /dev/null @@ -1,88 +0,0 @@ -package sqlancer.arangodb.query; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.stream.Collectors; - -import com.arangodb.ArangoCursor; -import com.arangodb.entity.BaseDocument; -import com.arangodb.model.AqlQueryOptions; - -import sqlancer.GlobalState; -import sqlancer.arangodb.ArangoDBConnection; -import sqlancer.arangodb.ArangoDBQueryAdapter; -import sqlancer.common.query.ExpectedErrors; -import sqlancer.common.query.SQLancerResultSet; - -public class ArangoDBSelectQuery extends ArangoDBQueryAdapter { - - private final String query; - - private List optimizerRules; - - private List resultSet; - - public ArangoDBSelectQuery(String query) { - this.query = query; - optimizerRules = new ArrayList<>(); - } - - @Override - public boolean couldAffectSchema() { - return false; - } - - @Override - public > boolean execute(G globalState, String... fills) - throws Exception { - throw new UnsupportedOperationException(); - } - - @Override - public ExpectedErrors getExpectedErrors() { - return new ExpectedErrors(); - } - - @Override - public String getLogString() { - if (optimizerRules.isEmpty()) { - return "db._query(\"" + query + "\")"; - } else { - String rules = optimizerRules.stream().map(Object::toString).collect(Collectors.joining("\",\"")); - return "db._query(\"" + query + "\", null, { optimizer: { rules: [\"" + rules + "\"] } } )"; - } - } - - @Override - public > SQLancerResultSet executeAndGet(G globalState, - String... fills) throws Exception { - if (globalState.getOptions().logEachSelect()) { - globalState.getLogger().writeCurrent(this.getLogString()); - try { - globalState.getLogger().getCurrentFileWriter().flush(); - } catch (IOException e) { - e.printStackTrace(); - } - } - - ArangoCursor cursor; - if (optimizerRules.isEmpty()) { - cursor = globalState.getConnection().getDatabase().query(query, BaseDocument.class); - } else { - AqlQueryOptions options = new AqlQueryOptions(); - cursor = globalState.getConnection().getDatabase().query(query, options.rules(optimizerRules), - BaseDocument.class); - } - resultSet = cursor.asListRemaining(); - return null; - } - - public List getResultSet() { - return resultSet; - } - - public void excludeRandomOptRules() { - optimizerRules = new ArangoDBOptimizerRules().getRandomRules(); - } -} diff --git a/src/sqlancer/arangodb/test/ArangoDBQueryPartitioningBase.java b/src/sqlancer/arangodb/test/ArangoDBQueryPartitioningBase.java deleted file mode 100644 index f583ed04f..000000000 --- a/src/sqlancer/arangodb/test/ArangoDBQueryPartitioningBase.java +++ /dev/null @@ -1,67 +0,0 @@ -package sqlancer.arangodb.test; - -import java.util.ArrayList; -import java.util.List; - -import sqlancer.Randomly; -import sqlancer.arangodb.ArangoDBProvider; -import sqlancer.arangodb.ArangoDBSchema; -import sqlancer.arangodb.ast.ArangoDBExpression; -import sqlancer.arangodb.ast.ArangoDBSelect; -import sqlancer.arangodb.gen.ArangoDBComputedExpressionGenerator; -import sqlancer.arangodb.gen.ArangoDBFilterExpressionGenerator; -import sqlancer.common.ast.newast.Node; -import sqlancer.common.gen.ExpressionGenerator; -import sqlancer.common.oracle.TernaryLogicPartitioningOracleBase; -import sqlancer.common.oracle.TestOracle; - -public class ArangoDBQueryPartitioningBase - extends TernaryLogicPartitioningOracleBase, ArangoDBProvider.ArangoDBGlobalState> - implements TestOracle { - - protected ArangoDBSchema schema; - protected List targetColumns; - protected ArangoDBFilterExpressionGenerator expressionGenerator; - protected ArangoDBSelect select; - protected int numberComputedColumns; - - protected ArangoDBQueryPartitioningBase(ArangoDBProvider.ArangoDBGlobalState state) { - super(state); - } - - @Override - protected ExpressionGenerator> getGen() { - return expressionGenerator; - } - - @Override - public void check() throws Exception { - numberComputedColumns = state.getRandomly().getInteger(0, 4); - schema = state.getSchema(); - generateTargetColumns(); - expressionGenerator = new ArangoDBFilterExpressionGenerator(state).setColumns(targetColumns); - expressionGenerator.setNumberOfComputedVariables(numberComputedColumns); - initializeTernaryPredicateVariants(); - select = new ArangoDBSelect<>(); - select.setFromColumns(targetColumns); - select.setProjectionColumns(Randomly.nonEmptySubset(targetColumns)); - generateComputedClause(); - } - - private void generateComputedClause() { - List> computedColumns = new ArrayList<>(); - ArangoDBComputedExpressionGenerator generator = new ArangoDBComputedExpressionGenerator(state); - generator.setColumns(targetColumns); - for (int i = 0; i < numberComputedColumns; i++) { - computedColumns.add(generator.generateExpression()); - } - select.setComputedClause(computedColumns); - } - - private void generateTargetColumns() { - ArangoDBSchema.ArangoDBTables targetTables; - targetTables = schema.getRandomTableNonEmptyTables(); - List allColumns = targetTables.getColumns(); - targetColumns = Randomly.nonEmptySubset(allColumns); - } -} diff --git a/src/sqlancer/arangodb/test/ArangoDBQueryPartitioningWhereTester.java b/src/sqlancer/arangodb/test/ArangoDBQueryPartitioningWhereTester.java deleted file mode 100644 index 6ad19fabf..000000000 --- a/src/sqlancer/arangodb/test/ArangoDBQueryPartitioningWhereTester.java +++ /dev/null @@ -1,46 +0,0 @@ -package sqlancer.arangodb.test; - -import static sqlancer.arangodb.ArangoDBComparatorHelper.assumeResultSetsAreEqual; -import static sqlancer.arangodb.ArangoDBComparatorHelper.getResultSetAsDocumentList; - -import java.util.List; - -import com.arangodb.entity.BaseDocument; - -import sqlancer.arangodb.ArangoDBProvider; -import sqlancer.arangodb.query.ArangoDBSelectQuery; -import sqlancer.arangodb.visitor.ArangoDBVisitor; - -public class ArangoDBQueryPartitioningWhereTester extends ArangoDBQueryPartitioningBase { - public ArangoDBQueryPartitioningWhereTester(ArangoDBProvider.ArangoDBGlobalState state) { - super(state); - } - - @Override - public void check() throws Exception { - super.check(); - select.setFilterClause(null); - - ArangoDBSelectQuery query = ArangoDBVisitor.asSelectQuery(select); - List firstResultSet = getResultSetAsDocumentList(query, state); - - select.setFilterClause(predicate); - query = ArangoDBVisitor.asSelectQuery(select); - List secondResultSet = getResultSetAsDocumentList(query, state); - - select.setFilterClause(negatedPredicate); - query = ArangoDBVisitor.asSelectQuery(select); - List thirdResultSet = getResultSetAsDocumentList(query, state); - - thirdResultSet.addAll(secondResultSet); - assumeResultSetsAreEqual(firstResultSet, thirdResultSet, query); - - if (state.getDbmsSpecificOptions().withOptimizerRuleTests) { - select.setFilterClause(predicate); - query = ArangoDBVisitor.asSelectQuery(select); - query.excludeRandomOptRules(); - List forthResultSet = getResultSetAsDocumentList(query, state); - assumeResultSetsAreEqual(secondResultSet, forthResultSet, query); - } - } -} diff --git a/src/sqlancer/arangodb/visitor/ArangoDBToQueryVisitor.java b/src/sqlancer/arangodb/visitor/ArangoDBToQueryVisitor.java deleted file mode 100644 index f82995d5e..000000000 --- a/src/sqlancer/arangodb/visitor/ArangoDBToQueryVisitor.java +++ /dev/null @@ -1,134 +0,0 @@ -package sqlancer.arangodb.visitor; - -import java.util.HashSet; -import java.util.List; -import java.util.Set; - -import sqlancer.arangodb.ArangoDBSchema; -import sqlancer.arangodb.ast.ArangoDBConstant; -import sqlancer.arangodb.ast.ArangoDBExpression; -import sqlancer.arangodb.ast.ArangoDBSelect; -import sqlancer.arangodb.gen.ArangoDBComputedExpressionGenerator; -import sqlancer.arangodb.query.ArangoDBSelectQuery; -import sqlancer.common.ast.newast.ColumnReferenceNode; -import sqlancer.common.ast.newast.NewBinaryOperatorNode; -import sqlancer.common.ast.newast.NewFunctionNode; -import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode; -import sqlancer.common.ast.newast.Node; - -public class ArangoDBToQueryVisitor extends ArangoDBVisitor { - - private final StringBuilder stringBuilder; - - public ArangoDBToQueryVisitor() { - stringBuilder = new StringBuilder(); - } - - @Override - protected void visit(ArangoDBSelect expression) { - generateFrom(expression); - generateComputed(expression); - generateFilter(expression); - generateProject(expression); - } - - private void generateFilter(ArangoDBSelect expression) { - if (expression.hasFilter()) { - stringBuilder.append("FILTER "); - visit(expression.getFilterClause()); - stringBuilder.append(" "); - } - } - - private void generateComputed(ArangoDBSelect expression) { - if (expression.hasComputed()) { - List> computedClause = expression.getComputedClause(); - int computedNumber = 0; - for (Node computedExpression : computedClause) { - stringBuilder.append("LET c").append(computedNumber).append(" = "); - visit(computedExpression); - stringBuilder.append(" "); - computedNumber++; - } - } - } - - @Override - protected void visit(ColumnReferenceNode expression) { - if (expression.getColumn().getTable().getName().equals("")) { - stringBuilder.append(expression.getColumn().getName()); - } else { - stringBuilder.append("r").append(expression.getColumn().getTable().getName()).append(".") - .append(expression.getColumn().getName()); - } - } - - @Override - protected void visit(ArangoDBConstant expression) { - stringBuilder.append(expression.getValue()); - } - - @Override - protected void visit(NewBinaryOperatorNode expression) { - stringBuilder.append("("); - visit(expression.getLeft()); - stringBuilder.append(" ").append(expression.getOperatorRepresentation()).append(" "); - visit(expression.getRight()); - stringBuilder.append(")"); - } - - @Override - protected void visit(NewUnaryPrefixOperatorNode expression) { - stringBuilder.append(expression.getOperatorRepresentation()).append("("); - visit(expression.getExpr()); - stringBuilder.append(")"); - } - - @Override - protected void visit(NewFunctionNode expression) { - if (!(expression.getFunc() instanceof ArangoDBComputedExpressionGenerator.ComputedFunction)) { - throw new UnsupportedOperationException(); - } - ArangoDBComputedExpressionGenerator.ComputedFunction function = (ArangoDBComputedExpressionGenerator.ComputedFunction) expression - .getFunc(); - // TODO: Support functions with a different number of arguments. - if (function.getNrArgs() != 2) { - throw new UnsupportedOperationException(); - } - stringBuilder.append("("); - visit(expression.getArgs().get(0)); - stringBuilder.append(" ").append(function.getOperatorName()).append(" "); - visit(expression.getArgs().get(1)); - stringBuilder.append(")"); - } - - private void generateFrom(ArangoDBSelect expression) { - List forColumns = expression.getFromColumns(); - Set tables = new HashSet<>(); - for (ArangoDBSchema.ArangoDBColumn column : forColumns) { - tables.add(column.getTable()); - } - - for (ArangoDBSchema.ArangoDBTable table : tables) { - stringBuilder.append("FOR r").append(table.getName()).append(" IN ").append(table.getName()).append(" "); - } - } - - private void generateProject(ArangoDBSelect expression) { - List projectColumns = expression.getProjectionColumns(); - stringBuilder.append("RETURN {"); - String filler = ""; - for (ArangoDBSchema.ArangoDBColumn column : projectColumns) { - stringBuilder.append(filler); - filler = ", "; - stringBuilder.append(column.getTable().getName()).append("_").append(column.getName()).append(": r") - .append(column.getTable().getName()).append(".").append(column.getName()); - } - stringBuilder.append("}"); - } - - public ArangoDBSelectQuery getQuery() { - return new ArangoDBSelectQuery(stringBuilder.toString()); - } - -} diff --git a/src/sqlancer/arangodb/visitor/ArangoDBVisitor.java b/src/sqlancer/arangodb/visitor/ArangoDBVisitor.java deleted file mode 100644 index f1db84cf5..000000000 --- a/src/sqlancer/arangodb/visitor/ArangoDBVisitor.java +++ /dev/null @@ -1,51 +0,0 @@ -package sqlancer.arangodb.visitor; - -import sqlancer.arangodb.ast.ArangoDBConstant; -import sqlancer.arangodb.ast.ArangoDBExpression; -import sqlancer.arangodb.ast.ArangoDBSelect; -import sqlancer.arangodb.query.ArangoDBSelectQuery; -import sqlancer.common.ast.newast.ColumnReferenceNode; -import sqlancer.common.ast.newast.NewBinaryOperatorNode; -import sqlancer.common.ast.newast.NewFunctionNode; -import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode; -import sqlancer.common.ast.newast.Node; - -public abstract class ArangoDBVisitor { - - protected abstract void visit(ArangoDBSelect expression); - - protected abstract void visit(ColumnReferenceNode expression); - - protected abstract void visit(ArangoDBConstant expression); - - protected abstract void visit(NewBinaryOperatorNode expression); - - protected abstract void visit(NewUnaryPrefixOperatorNode expression); - - protected abstract void visit(NewFunctionNode expression); - - @SuppressWarnings("unchecked") - public void visit(Node expressionNode) { - if (expressionNode instanceof ArangoDBSelect) { - visit((ArangoDBSelect) expressionNode); - } else if (expressionNode instanceof ColumnReferenceNode) { - visit((ColumnReferenceNode) expressionNode); - } else if (expressionNode instanceof ArangoDBConstant) { - visit((ArangoDBConstant) expressionNode); - } else if (expressionNode instanceof NewBinaryOperatorNode) { - visit((NewBinaryOperatorNode) expressionNode); - } else if (expressionNode instanceof NewUnaryPrefixOperatorNode) { - visit((NewUnaryPrefixOperatorNode) expressionNode); - } else if (expressionNode instanceof NewFunctionNode) { - visit((NewFunctionNode) expressionNode); - } else { - throw new AssertionError(expressionNode); - } - } - - public static ArangoDBSelectQuery asSelectQuery(Node expressionNode) { - ArangoDBToQueryVisitor visitor = new ArangoDBToQueryVisitor(); - visitor.visit(expressionNode); - return visitor.getQuery(); - } -} diff --git a/src/sqlancer/citus/CitusBugs.java b/src/sqlancer/citus/CitusBugs.java index bc81b9db6..a6f4910e6 100644 --- a/src/sqlancer/citus/CitusBugs.java +++ b/src/sqlancer/citus/CitusBugs.java @@ -30,6 +30,9 @@ public final class CitusBugs { // https://github.com/citusdata/citus/issues/4079 public static boolean bug4079 = true; + // https://github.com/citusdata/citus/issues/6298 + public static boolean bug6298 = true; + private CitusBugs() { } diff --git a/src/sqlancer/citus/CitusOptions.java b/src/sqlancer/citus/CitusOptions.java index cdeadb672..f2d1b0abc 100644 --- a/src/sqlancer/citus/CitusOptions.java +++ b/src/sqlancer/citus/CitusOptions.java @@ -1,22 +1,11 @@ package sqlancer.citus; -import java.sql.SQLException; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; import com.beust.jcommander.Parameter; -import sqlancer.OracleFactory; -import sqlancer.citus.oracle.CitusNoRECOracle; -import sqlancer.citus.oracle.tlp.CitusTLPAggregateOracle; -import sqlancer.citus.oracle.tlp.CitusTLPHavingOracle; -import sqlancer.citus.oracle.tlp.CitusTLPWhereOracle; -import sqlancer.common.oracle.CompositeTestOracle; -import sqlancer.common.oracle.TestOracle; -import sqlancer.postgres.PostgresGlobalState; import sqlancer.postgres.PostgresOptions; -import sqlancer.postgres.oracle.PostgresPivotedQuerySynthesisOracle; public class CitusOptions extends PostgresOptions { @@ -26,41 +15,4 @@ public class CitusOptions extends PostgresOptions { @Parameter(names = "--citusoracle", description = "Specifies which test oracle should be used for Citus extension to PostgreSQL") public List citusOracle = Arrays.asList(CitusOracleFactory.QUERY_PARTITIONING); - public enum CitusOracleFactory implements OracleFactory { - NOREC { - @Override - public TestOracle create(PostgresGlobalState globalState) throws SQLException { - CitusGlobalState citusGlobalState = (CitusGlobalState) globalState; - return new CitusNoRECOracle(citusGlobalState); - } - }, - PQS { - @Override - public TestOracle create(PostgresGlobalState globalState) throws SQLException { - return new PostgresPivotedQuerySynthesisOracle(globalState); - } - }, - HAVING { - - @Override - public TestOracle create(PostgresGlobalState globalState) throws SQLException { - CitusGlobalState citusGlobalState = (CitusGlobalState) globalState; - return new CitusTLPHavingOracle(citusGlobalState); - } - - }, - QUERY_PARTITIONING { - @Override - public TestOracle create(PostgresGlobalState globalState) throws SQLException { - CitusGlobalState citusGlobalState = (CitusGlobalState) globalState; - List oracles = new ArrayList<>(); - oracles.add(new CitusTLPWhereOracle(citusGlobalState)); - oracles.add(new CitusTLPHavingOracle(citusGlobalState)); - oracles.add(new CitusTLPAggregateOracle(citusGlobalState)); - return new CompositeTestOracle(oracles, globalState); - } - }; - - } - } diff --git a/src/sqlancer/citus/CitusOracleFactory.java b/src/sqlancer/citus/CitusOracleFactory.java new file mode 100644 index 000000000..b6e0a7f3e --- /dev/null +++ b/src/sqlancer/citus/CitusOracleFactory.java @@ -0,0 +1,70 @@ +package sqlancer.citus; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import sqlancer.OracleFactory; +import sqlancer.citus.gen.CitusCommon; +import sqlancer.citus.oracle.tlp.CitusTLPAggregateOracle; +import sqlancer.citus.oracle.tlp.CitusTLPHavingOracle; +import sqlancer.common.oracle.CompositeTestOracle; +import sqlancer.common.oracle.NoRECOracle; +import sqlancer.common.oracle.TLPWhereOracle; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.postgres.PostgresGlobalState; +import sqlancer.postgres.gen.PostgresCommon; +import sqlancer.postgres.gen.PostgresExpressionGenerator; +import sqlancer.postgres.oracle.PostgresPivotedQuerySynthesisOracle; + +public enum CitusOracleFactory implements OracleFactory { + NOREC { + @Override + public TestOracle create(PostgresGlobalState globalState) throws SQLException { + PostgresExpressionGenerator gen = new PostgresExpressionGenerator(globalState); + ExpectedErrors errors = ExpectedErrors.newErrors().with(PostgresCommon.getCommonExpressionErrors()) + .with(PostgresCommon.getCommonFetchErrors()) + .withRegex(PostgresCommon.getCommonExpressionRegexErrors()) + .with(CitusCommon.getCitusErrors().toArray(new String[0])).build(); + return new NoRECOracle<>(globalState, gen, errors); + } + }, + PQS { + @Override + public TestOracle create(PostgresGlobalState globalState) throws SQLException { + return new PostgresPivotedQuerySynthesisOracle(globalState); + } + }, + WHERE { + @Override + public TestOracle create(PostgresGlobalState globalState) throws SQLException { + PostgresExpressionGenerator gen = new PostgresExpressionGenerator(globalState); + ExpectedErrors expectedErrors = ExpectedErrors.newErrors().with(PostgresCommon.getCommonExpressionErrors()) + .with(PostgresCommon.getCommonFetchErrors()) + .withRegex(PostgresCommon.getCommonExpressionRegexErrors()).with(CitusCommon.getCitusErrors()) + .build(); + + return new TLPWhereOracle<>(globalState, gen, expectedErrors); + } + }, + HAVING { + @Override + public TestOracle create(PostgresGlobalState globalState) throws SQLException { + CitusGlobalState citusGlobalState = (CitusGlobalState) globalState; + return new CitusTLPHavingOracle(citusGlobalState); + } + }, + QUERY_PARTITIONING { + @Override + public TestOracle create(PostgresGlobalState globalState) throws Exception { + CitusGlobalState citusGlobalState = (CitusGlobalState) globalState; + List> oracles = new ArrayList<>(); + oracles.add(WHERE.create(citusGlobalState)); + oracles.add(HAVING.create(citusGlobalState)); + oracles.add(new CitusTLPAggregateOracle(citusGlobalState)); + return new CompositeTestOracle(oracles, globalState); + } + }; + +} diff --git a/src/sqlancer/citus/CitusProvider.java b/src/sqlancer/citus/CitusProvider.java index aaebd6709..a11424f18 100644 --- a/src/sqlancer/citus/CitusProvider.java +++ b/src/sqlancer/citus/CitusProvider.java @@ -21,11 +21,15 @@ import sqlancer.citus.gen.CitusAlterTableGenerator; import sqlancer.citus.gen.CitusCommon; import sqlancer.citus.gen.CitusDeleteGenerator; +import sqlancer.citus.gen.CitusDiscardGenerator; import sqlancer.citus.gen.CitusIndexGenerator; import sqlancer.citus.gen.CitusInsertGenerator; +import sqlancer.citus.gen.CitusReindexGenerator; import sqlancer.citus.gen.CitusSetGenerator; import sqlancer.citus.gen.CitusTableGenerator; +import sqlancer.citus.gen.CitusTruncateGenerator; import sqlancer.citus.gen.CitusUpdateGenerator; +import sqlancer.citus.gen.CitusVacuumGenerator; import sqlancer.citus.gen.CitusViewGenerator; import sqlancer.common.DBMSCommon; import sqlancer.common.oracle.CompositeTestOracle; @@ -44,15 +48,11 @@ import sqlancer.postgres.gen.PostgresAnalyzeGenerator; import sqlancer.postgres.gen.PostgresClusterGenerator; import sqlancer.postgres.gen.PostgresCommentGenerator; -import sqlancer.postgres.gen.PostgresDiscardGenerator; import sqlancer.postgres.gen.PostgresDropIndexGenerator; import sqlancer.postgres.gen.PostgresNotifyGenerator; -import sqlancer.postgres.gen.PostgresReindexGenerator; import sqlancer.postgres.gen.PostgresSequenceGenerator; import sqlancer.postgres.gen.PostgresStatisticsGenerator; import sqlancer.postgres.gen.PostgresTransactionGenerator; -import sqlancer.postgres.gen.PostgresTruncateGenerator; -import sqlancer.postgres.gen.PostgresVacuumGenerator; @AutoService(DatabaseProvider.class) public class CitusProvider extends PostgresProvider { @@ -82,13 +82,13 @@ public enum Action implements AbstractAction { CREATE_STATISTICS(PostgresStatisticsGenerator::insert), // DROP_STATISTICS(PostgresStatisticsGenerator::remove), // DELETE(CitusDeleteGenerator::create), // - DISCARD(PostgresDiscardGenerator::create), // + DISCARD(CitusDiscardGenerator::create), // DROP_INDEX(PostgresDropIndexGenerator::create), // INSERT(CitusInsertGenerator::insert), // UPDATE(CitusUpdateGenerator::create), // - TRUNCATE(PostgresTruncateGenerator::create), // - VACUUM(PostgresVacuumGenerator::create), // - REINDEX(PostgresReindexGenerator::create), // + TRUNCATE(CitusTruncateGenerator::create), // + VACUUM(CitusVacuumGenerator::create), // + REINDEX(CitusReindexGenerator::create), // SET(CitusSetGenerator::create), // CREATE_INDEX(CitusIndexGenerator::generate), // SET_CONSTRAINTS((g) -> { @@ -319,15 +319,16 @@ public void generateDatabase(PostgresGlobalState globalState) throws Exception { } @Override - protected TestOracle getTestOracle(PostgresGlobalState globalState) throws SQLException { - List oracles = ((CitusOptions) globalState.getDbmsSpecificOptions()).citusOracle.stream().map(o -> { - try { - return o.create(globalState); - } catch (Exception e1) { - throw new AssertionError(e1); - } - }).collect(Collectors.toList()); - return new CompositeTestOracle(oracles, globalState); + protected TestOracle getTestOracle(PostgresGlobalState globalState) throws SQLException { + List> oracles = ((CitusOptions) globalState + .getDbmsSpecificOptions()).citusOracle.stream().map(o -> { + try { + return o.create(globalState); + } catch (Exception e1) { + throw new AssertionError(e1); + } + }).collect(Collectors.toList()); + return new CompositeTestOracle(oracles, globalState); } private List readCitusWorkerNodes(PostgresGlobalState globalState, SQLConnection con) @@ -370,6 +371,10 @@ private void prepareCitusWorkerNodes(PostgresGlobalState globalState, List pg_backend_pid()"); s.execute("DROP DATABASE IF EXISTS " + databaseName); } try (Statement s = con.createStatement()) { diff --git a/src/sqlancer/citus/CitusSchema.java b/src/sqlancer/citus/CitusSchema.java index defda8804..b2550bdce 100644 --- a/src/sqlancer/citus/CitusSchema.java +++ b/src/sqlancer/citus/CitusSchema.java @@ -58,41 +58,35 @@ public Integer getColocationId() { public static CitusSchema fromConnection(SQLConnection con, String databaseName) throws SQLException { PostgresSchema schema = PostgresSchema.fromConnection(con, databaseName); - try { - List databaseTables = new ArrayList<>(); - try (Statement s = con.createStatement()) { - try (ResultSet rs = s.executeQuery( - "SELECT table_name, column_to_column_name(logicalrelid, partkey) AS dist_col_name, colocationid FROM information_schema.tables LEFT OUTER JOIN pg_dist_partition ON logicalrelid=table_name::regclass WHERE table_schema='public' OR table_schema LIKE 'pg_temp_%';")) { - while (rs.next()) { - String tableName = rs.getString("table_name"); - /* citus_tables is a helper view, we don't need to test with it so we let's ignore it */ - if (tableName.equals("citus_tables")) { - continue; - } - String distributionColumnName = rs.getString("dist_col_name"); - Integer colocationId = rs.getInt("colocationid"); - if (rs.wasNull()) { - colocationId = null; - } - PostgresTable t = schema.getDatabaseTable(tableName); - PostgresColumn distributionColumn = null; - if (t == null) { - continue; - } - if (distributionColumnName != null && !distributionColumnName.equals("")) { - distributionColumn = t.getColumns().stream() - .filter(c -> c.getName().equals(distributionColumnName)) - .collect(Collectors.toList()).get(0); - } - CitusTable tCitus = new CitusTable(t, distributionColumn, colocationId); - databaseTables.add(tCitus); - } + List databaseTables = new ArrayList<>(); + try (Statement s = con.createStatement(); ResultSet rs = s.executeQuery( + "SELECT table_name, column_to_column_name(logicalrelid, partkey) AS dist_col_name, colocationid FROM information_schema.tables LEFT OUTER JOIN pg_dist_partition ON logicalrelid=table_name::regclass WHERE table_schema='public' OR table_schema LIKE 'pg_temp_%';")) { + while (rs.next()) { + String tableName = rs.getString("table_name"); + /* skip Citus-managed views in the public schema (citus_tables, citus_schemas, etc.) */ + if (tableName.startsWith("citus_")) { + continue; } + String distributionColumnName = rs.getString("dist_col_name"); + Integer colocationId = rs.getInt("colocationid"); + if (rs.wasNull()) { + colocationId = null; + } + PostgresTable t = schema.getDatabaseTable(tableName); + PostgresColumn distributionColumn = null; + if (t == null) { + continue; + } + if (distributionColumnName != null && !distributionColumnName.equals("")) { + distributionColumn = t.getColumns().stream().filter(c -> c.getName().equals(distributionColumnName)) + .collect(Collectors.toList()).get(0); + } + CitusTable tCitus = new CitusTable(t, distributionColumn, colocationId); + databaseTables.add(tCitus); } - return new CitusSchema(databaseTables, databaseName); } catch (SQLIntegrityConstraintViolationException e) { throw new AssertionError(e); } + return new CitusSchema(databaseTables, databaseName); } - } diff --git a/src/sqlancer/citus/gen/CitusCommon.java b/src/sqlancer/citus/gen/CitusCommon.java index f4a121efe..58b1b7c16 100644 --- a/src/sqlancer/citus/gen/CitusCommon.java +++ b/src/sqlancer/citus/gen/CitusCommon.java @@ -1,5 +1,8 @@ package sqlancer.citus.gen; +import java.util.ArrayList; +import java.util.List; + import sqlancer.citus.CitusBugs; import sqlancer.common.query.ExpectedErrors; @@ -8,21 +11,24 @@ public final class CitusCommon { private CitusCommon() { } - public static void addCitusErrors(ExpectedErrors errors) { + public static List getCitusErrors() { // not supported by Citus + ArrayList errors = new ArrayList<>(); errors.add("failed to evaluate partition key in insert"); errors.add("cannot perform an INSERT without a partition column value"); errors.add("cannot perform an INSERT with NULL in the partition column"); errors.add("recursive CTEs are not supported in distributed queries"); + errors.add("recursive CTEs are only supported when they contain a filter on the distribution column"); errors.add("could not run distributed query with GROUPING SETS, CUBE, or ROLLUP"); errors.add("Subqueries in HAVING cannot refer to outer query"); errors.add("non-IMMUTABLE functions are not allowed in the RETURNING clause"); errors.add("functions used in UPDATE queries on distributed tables must not be VOLATILE"); errors.add("STABLE functions used in UPDATE queries cannot be called with column references"); - errors.add( - "functions used in the WHERE clause of modification queries on distributed tables must not be VOLATILE"); + errors.add("of modification queries on distributed tables must not be VOLATILE"); errors.add("cannot execute ADD CONSTRAINT command with other subcommands"); errors.add("cannot execute ALTER TABLE command involving partition column"); + errors.add("alter table command is currently unsupported"); + errors.add("on distributed partitioned tables are not supported"); errors.add("could not run distributed query with FOR UPDATE/SHARE commands"); errors.add("is not a regular, foreign or partitioned table"); errors.add("must be a distributed table or a reference table"); @@ -49,7 +55,15 @@ public static void addCitusErrors(ExpectedErrors errors) { errors.add("direct joins between distributed and local tables are not supported"); errors.add("unlogged columnar tables are not supported"); errors.add("UPDATE and CTID scans not supported for ColumnarScan"); - errors.add("indexes not supported for columnar tables"); + errors.add("unsupported access method for the index on columnar table"); + errors.add("BRIN indexes on columnar tables are not supported"); + errors.add("invalid byte sequence for encoding \"UTF8\": 0x00"); + errors.add("columnar_tuple_insert_speculative not implemented"); + errors.add("row field count is 1, expected 2"); + errors.add("incorrect binary data format"); + errors.add("invalid sign in external \"numeric\" value"); + errors.add("Foreign keys and AFTER ROW triggers are not supported for columnar tables"); + errors.addAll(getColumnarOidErrors()); // current errors in Citus (to be removed once fixed) if (CitusBugs.bug3957) { @@ -74,6 +88,23 @@ public static void addCitusErrors(ExpectedErrors errors) { if (CitusBugs.bug4079) { errors.add("aggregate function calls cannot be nested"); } + return errors; } + /** + * Citus can fail with "could not open relation with OID 0" when operating on columnar temporary tables (e.g., USING + * columnar ON COMMIT DROP), during VACUUM, DISCARD TEMPORARY, or INSERT operations where Citus cannot resolve the + * relation OID. + * + * @return the list of expected error substrings for columnar OID resolution failures. + */ + public static List getColumnarOidErrors() { + List errors = new ArrayList<>(); + errors.add("could not open relation with OID 0"); + return errors; + } + + public static void addCitusErrors(ExpectedErrors errors) { + errors.addAll(getCitusErrors()); + } } diff --git a/src/sqlancer/citus/gen/CitusDiscardGenerator.java b/src/sqlancer/citus/gen/CitusDiscardGenerator.java new file mode 100644 index 000000000..f4a1b3240 --- /dev/null +++ b/src/sqlancer/citus/gen/CitusDiscardGenerator.java @@ -0,0 +1,20 @@ +package sqlancer.citus.gen; + +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.postgres.PostgresGlobalState; +import sqlancer.postgres.gen.PostgresDiscardGenerator; + +public final class CitusDiscardGenerator { + + private CitusDiscardGenerator() { + } + + public static SQLQueryAdapter create(PostgresGlobalState globalState) { + SQLQueryAdapter discardQuery = PostgresDiscardGenerator.create(globalState); + ExpectedErrors errors = discardQuery.getExpectedErrors(); + CitusCommon.addCitusErrors(errors); + return discardQuery; + } + +} diff --git a/src/sqlancer/citus/gen/CitusInsertGenerator.java b/src/sqlancer/citus/gen/CitusInsertGenerator.java index 269d94bbe..78b5794e7 100644 --- a/src/sqlancer/citus/gen/CitusInsertGenerator.java +++ b/src/sqlancer/citus/gen/CitusInsertGenerator.java @@ -1,5 +1,6 @@ package sqlancer.citus.gen; +import sqlancer.citus.CitusBugs; import sqlancer.common.query.ExpectedErrors; import sqlancer.common.query.SQLQueryAdapter; import sqlancer.postgres.PostgresGlobalState; @@ -14,6 +15,9 @@ public static SQLQueryAdapter insert(PostgresGlobalState globalState) { SQLQueryAdapter insertQuery = PostgresInsertGenerator.insert(globalState); ExpectedErrors errors = insertQuery.getExpectedErrors(); CitusCommon.addCitusErrors(errors); + if (CitusBugs.bug6298) { + errors.add("columnar_tuple_insert_speculative not implemented"); + } return insertQuery; } diff --git a/src/sqlancer/citus/gen/CitusReindexGenerator.java b/src/sqlancer/citus/gen/CitusReindexGenerator.java new file mode 100644 index 000000000..6f37cbe06 --- /dev/null +++ b/src/sqlancer/citus/gen/CitusReindexGenerator.java @@ -0,0 +1,20 @@ +package sqlancer.citus.gen; + +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.postgres.PostgresGlobalState; +import sqlancer.postgres.gen.PostgresReindexGenerator; + +public final class CitusReindexGenerator { + + private CitusReindexGenerator() { + } + + public static SQLQueryAdapter create(PostgresGlobalState globalState) { + SQLQueryAdapter reindexQuery = PostgresReindexGenerator.create(globalState); + ExpectedErrors errors = reindexQuery.getExpectedErrors(); + CitusCommon.addCitusErrors(errors); + return reindexQuery; + } + +} diff --git a/src/sqlancer/citus/gen/CitusTruncateGenerator.java b/src/sqlancer/citus/gen/CitusTruncateGenerator.java new file mode 100644 index 000000000..cf36ce9c2 --- /dev/null +++ b/src/sqlancer/citus/gen/CitusTruncateGenerator.java @@ -0,0 +1,20 @@ +package sqlancer.citus.gen; + +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.postgres.PostgresGlobalState; +import sqlancer.postgres.gen.PostgresTruncateGenerator; + +public final class CitusTruncateGenerator { + + private CitusTruncateGenerator() { + } + + public static SQLQueryAdapter create(PostgresGlobalState globalState) { + SQLQueryAdapter truncateQuery = PostgresTruncateGenerator.create(globalState); + ExpectedErrors errors = truncateQuery.getExpectedErrors(); + CitusCommon.addCitusErrors(errors); + return truncateQuery; + } + +} diff --git a/src/sqlancer/citus/gen/CitusVacuumGenerator.java b/src/sqlancer/citus/gen/CitusVacuumGenerator.java new file mode 100644 index 000000000..ae73dbf82 --- /dev/null +++ b/src/sqlancer/citus/gen/CitusVacuumGenerator.java @@ -0,0 +1,20 @@ +package sqlancer.citus.gen; + +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.postgres.PostgresGlobalState; +import sqlancer.postgres.gen.PostgresVacuumGenerator; + +public final class CitusVacuumGenerator { + + private CitusVacuumGenerator() { + } + + public static SQLQueryAdapter create(PostgresGlobalState globalState) { + SQLQueryAdapter vacuumQuery = PostgresVacuumGenerator.create(globalState); + ExpectedErrors errors = vacuumQuery.getExpectedErrors(); + CitusCommon.addCitusErrors(errors); + return vacuumQuery; + } + +} diff --git a/src/sqlancer/citus/oracle/CitusNoRECOracle.java b/src/sqlancer/citus/oracle/CitusNoRECOracle.java deleted file mode 100644 index 88ec3391a..000000000 --- a/src/sqlancer/citus/oracle/CitusNoRECOracle.java +++ /dev/null @@ -1,14 +0,0 @@ -package sqlancer.citus.oracle; - -import sqlancer.citus.gen.CitusCommon; -import sqlancer.postgres.PostgresGlobalState; -import sqlancer.postgres.oracle.PostgresNoRECOracle; - -public class CitusNoRECOracle extends PostgresNoRECOracle { - - public CitusNoRECOracle(PostgresGlobalState globalState) { - super(globalState); - CitusCommon.addCitusErrors(errors); - } - -} diff --git a/src/sqlancer/citus/oracle/tlp/CitusTLPWhereOracle.java b/src/sqlancer/citus/oracle/tlp/CitusTLPWhereOracle.java deleted file mode 100644 index 1ccecd0c5..000000000 --- a/src/sqlancer/citus/oracle/tlp/CitusTLPWhereOracle.java +++ /dev/null @@ -1,35 +0,0 @@ -package sqlancer.citus.oracle.tlp; - -import java.sql.SQLException; -import java.util.Arrays; - -import sqlancer.citus.CitusGlobalState; -import sqlancer.citus.gen.CitusCommon; -import sqlancer.postgres.PostgresGlobalState; -import sqlancer.postgres.oracle.tlp.PostgresTLPWhereOracle; - -public class CitusTLPWhereOracle extends PostgresTLPWhereOracle { - - private final CitusTLPBase citusTLPBase; - - public CitusTLPWhereOracle(CitusGlobalState state) { - super(state); - CitusCommon.addCitusErrors(errors); - citusTLPBase = new CitusTLPBase(state); - } - - @Override - public void check() throws SQLException { - state.setAllowedFunctionTypes(Arrays.asList(PostgresGlobalState.IMMUTABLE)); - citusTLPBase.check(); - s = citusTLPBase.getSchema(); - targetTables = citusTLPBase.getTargetTables(); - gen = citusTLPBase.getGenerator(); - select = citusTLPBase.getSelect(); - predicate = citusTLPBase.getPredicate(); - negatedPredicate = citusTLPBase.getNegatedPredicate(); - isNullPredicate = citusTLPBase.getIsNullPredicate(); - whereCheck(); - state.setDefaultAllowedFunctionTypes(); - } -} diff --git a/src/sqlancer/clickhouse/ClickHouseErrors.java b/src/sqlancer/clickhouse/ClickHouseErrors.java index 17993c7e4..09fbe5ea8 100644 --- a/src/sqlancer/clickhouse/ClickHouseErrors.java +++ b/src/sqlancer/clickhouse/ClickHouseErrors.java @@ -1,5 +1,7 @@ package sqlancer.clickhouse; +import java.util.List; + import sqlancer.common.query.ExpectedErrors; public final class ClickHouseErrors { @@ -7,61 +9,58 @@ public final class ClickHouseErrors { private ClickHouseErrors() { } - public static void addExpectedExpressionErrors(ExpectedErrors errors) { - errors.add("Illegal type"); - errors.add("Argument at index 1 for function like must be constant"); - errors.add("Argument at index 1 for function notLike must be constant"); - errors.add("does not return a value of type UInt8"); - errors.add("invalid escape sequence"); - errors.add("invalid character class range"); - errors.add("Memory limit"); - errors.add("There is no supertype for types"); - errors.add("Bad get: has Int64, requested UInt64"); - errors.add("Cannot convert string"); - errors.add("Cannot read floating point value"); - errors.add("Cannot parse infinity."); - errors.add("Attempt to read after eof: while converting"); - errors.add( - "is violated, because it is a constant expression returning 0. It is most likely an error in table definition"); - errors.add("doesn't exist"); // TODO: consecutive test runs can lead to dropped database - errors.add("is not under aggregate function"); - errors.add("Invalid type for filter in"); - errors.add("argument of function"); - errors.add(" is not under aggregate function and not in GROUP BY"); - errors.add("Expected one of: compound identifier, identifier, list of elements (version"); // VALUES () - errors.add("OptimizedRegularExpression: cannot compile re2"); - errors.add("because it is constant but values of constants are different in source and result"); // https://github.com/ClickHouse/ClickHouse/issues/22119 - errors.add("is violated at row 1. Expression:"); // TODO: check constraint on table creation - errors.add("Cannot parse NaN.: while converting"); // https://github.com/ClickHouse/ClickHouse/issues/22710 - errors.add("Cannot parse number with a sign character but without any numeric character"); - errors.add("Cannot parse number with multiple sign (+/-) characters or intermediate sign character"); - errors.add("Function 'like' doesn't support search with non-constant needles in constant haystack"); - errors.add("Positional argument out of bounds"); - } - - public static void addExpressionHavingErrors(ExpectedErrors errors) { - errors.add("Memory limit"); - } - - public static void addQueryErrors(ExpectedErrors errors) { - errors.add("Memory limit"); - } + public static List getExpectedExpressionErrors() { + return List.of("Argument at index 1 for function like must be constant", + "Argument at index 1 for function notLike must be constant", + "Attempt to read after eof: while converting", "Bad get: has Int64, requested UInt64", + "Cannot convert string", "Cannot insert NULL value into a column of type", + "Cannot parse Int32 from String, because value is too short", "Cannot parse NaN.: while converting", // https://github.com/ClickHouse/ClickHouse/issues/22710 + "Cannot parse infinity.", "Cannot parse number with a sign character but without any numeric character", + "Cannot parse number with multiple sign (+/-) characters or intermediate sign character", + "Cannot parse string", "Cannot read floating point value", + "Cyclic aliases: default expression and column type are incompatible", "Directory for table data", + "Directory not empty", "Expected one of: compound identifier, identifier, list of elements (version", // VALUES + // () + "Function 'like' doesn't support search with non-constant needles in constant haystack", "Illegal type", + "Illegal value (aggregate function) for positional argument in GROUP BY", + "Invalid escape sequence at the end of LIKE pattern", "Invalid type for filter in", "Memory limit", + "OptimizedRegularExpression: cannot compile re2", "Partition key cannot contain constants", + "Positional argument out of bounds", "Sampling expression must be present in the primary key", + "Sorting key cannot contain constants", "There is no supertype for types", "argument of function", + "but its arguments considered equal according to constraints", "does not return a value of type UInt8", + "doesn't exist", // TODO: consecutive test runs can lead to dropped database + "in block. There are only columns:", // https://github.com/ClickHouse/ClickHouse/issues/42399 + "invalid character class range", "invalid escape sequence", + "is not under aggregate function and not in GROUP BY", "is not under aggregate function", + "is violated at row 1. Expression:", // TODO: check constraint on table creation + "is violated, because it is a constant expression returning 0. It is most likely an error in table definition", + "there are only columns", "there are columns", "(NOT_FOUND_COLUMN_IN_BLOCK)", "Missing columns", + "Ambiguous column", "Must be one unsigned integer type. (ILLEGAL_TYPE_OF_COLUMN_FOR_FILTER)", + "Floating point partition key is not supported", "Cannot get JOIN keys from JOIN ON section", + "ILLEGAL_DIVISION", "DECIMAL_OVERFLOW", + "Cannot convert out of range floating point value to integer type", + "Unexpected inf or nan to integer conversion", "No such name in Block::erase", // https://github.com/ClickHouse/ClickHouse/issues/42769 + "EMPTY_LIST_OF_COLUMNS_QUERIED", // https://github.com/ClickHouse/ClickHouse/issues/43003 + "EMPTY_LIST_OF_COLUMNS_PASSED", // https://github.com/ClickHouse/ClickHouse/pull/81835 + "cannot get JOIN keys. (INVALID_JOIN_ON_EXPRESSION)", "AMBIGUOUS_IDENTIFIER", "CYCLIC_ALIASES", + "Positional argument numeric constant expression is not representable as", + "Positional argument must be constant with numeric type", " is out of bounds. Expected in range", + "with constants is not supported. (INVALID_JOIN_ON_EXPRESSION)", + "Cannot get JOIN keys from JOIN ON section", "Unexpected inf or nan to integer conversion", + "Cannot determine join keys in", "Unsigned type must not contain", + "Unexpected inf or nan to integer conversion", - public static void addGroupingErrors(ExpectedErrors errors) { - errors.add("Memory limit"); + // The way we generate JOINs we can have ambiguous left table column without + // alias + // We may not count it as an issue, but it makes no sense to add more complex + // AST generation logic + "MULTIPLE_EXPRESSIONS_FOR_ALIAS", "AMBIGUOUS_IDENTIFIER", // https://github.com/ClickHouse/ClickHouse/issues/45389 + "AMBIGUOUS_COLUMN_NAME", // same https://github.com/ClickHouse/ClickHouse/issues/45389 + "No equality condition found in JOIN ON expression", "Cannot parse number with multiple sign"); } - public static void addTableManipulationErrors(ExpectedErrors errors) { - errors.add("Memory limit"); - errors.add("Directory for table data"); - errors.add("Directory not empty"); - errors.add("Partition key cannot contain constants"); - errors.add("Cannot convert string"); - errors.add("argument of function"); - errors.add("Attempt to read after eof: while converting"); - errors.add("Sorting key cannot contain constants"); - errors.add("Sampling expression must be present in the primary key"); - errors.add("Cyclic aliases: default expression and column type are incompatible"); + public static void addExpectedExpressionErrors(ExpectedErrors errors) { + errors.addAll(getExpectedExpressionErrors()); } } diff --git a/src/sqlancer/clickhouse/ClickHouseOptions.java b/src/sqlancer/clickhouse/ClickHouseOptions.java index c58697c98..fec8b62a7 100644 --- a/src/sqlancer/clickhouse/ClickHouseOptions.java +++ b/src/sqlancer/clickhouse/ClickHouseOptions.java @@ -1,6 +1,5 @@ package sqlancer.clickhouse; -import java.sql.SQLException; import java.util.Arrays; import java.util.List; @@ -8,15 +7,6 @@ import com.beust.jcommander.Parameters; import sqlancer.DBMSSpecificOptions; -import sqlancer.OracleFactory; -import sqlancer.clickhouse.ClickHouseOptions.ClickHouseOracleFactory; -import sqlancer.clickhouse.ClickHouseProvider.ClickHouseGlobalState; -import sqlancer.clickhouse.oracle.tlp.ClickHouseTLPAggregateOracle; -import sqlancer.clickhouse.oracle.tlp.ClickHouseTLPDistinctOracle; -import sqlancer.clickhouse.oracle.tlp.ClickHouseTLPGroupByOracle; -import sqlancer.clickhouse.oracle.tlp.ClickHouseTLPHavingOracle; -import sqlancer.clickhouse.oracle.tlp.ClickHouseTLPWhereOracle; -import sqlancer.common.oracle.TestOracle; @Parameters(separators = "=", commandDescription = "ClickHouse (default port: " + ClickHouseOptions.DEFAULT_PORT + ", default host: " + ClickHouseOptions.DEFAULT_HOST + ")") @@ -30,39 +20,8 @@ public class ClickHouseOptions implements DBMSSpecificOptions { - TLPWhere { - @Override - public TestOracle create(ClickHouseGlobalState globalState) throws SQLException { - return new ClickHouseTLPWhereOracle(globalState); - } - }, - TLPDistinct { - @Override - public TestOracle create(ClickHouseGlobalState globalState) throws SQLException { - return new ClickHouseTLPDistinctOracle(globalState); - } - }, - TLPGroupBy { - @Override - public TestOracle create(ClickHouseGlobalState globalState) throws SQLException { - return new ClickHouseTLPGroupByOracle(globalState); - } - }, - TLPAggregate { - @Override - public TestOracle create(ClickHouseGlobalState globalState) throws SQLException { - return new ClickHouseTLPAggregateOracle(globalState); - } - }, - TLPHaving { - @Override - public TestOracle create(ClickHouseGlobalState globalState) throws SQLException { - return new ClickHouseTLPHavingOracle(globalState); - } - }; - - } + @Parameter(names = { "--analyzer" }, description = "Enable analyzer in ClickHouse", arity = 1) + public boolean enableAnalyzer = true; @Override public List getTestOracleFactory() { diff --git a/src/sqlancer/clickhouse/ClickHouseOracleFactory.java b/src/sqlancer/clickhouse/ClickHouseOracleFactory.java new file mode 100644 index 000000000..53b38a6f1 --- /dev/null +++ b/src/sqlancer/clickhouse/ClickHouseOracleFactory.java @@ -0,0 +1,62 @@ +package sqlancer.clickhouse; + +import java.sql.SQLException; + +import sqlancer.OracleFactory; +import sqlancer.clickhouse.ClickHouseProvider.ClickHouseGlobalState; +import sqlancer.clickhouse.gen.ClickHouseExpressionGenerator; +import sqlancer.clickhouse.oracle.tlp.ClickHouseTLPAggregateOracle; +import sqlancer.clickhouse.oracle.tlp.ClickHouseTLPDistinctOracle; +import sqlancer.clickhouse.oracle.tlp.ClickHouseTLPGroupByOracle; +import sqlancer.clickhouse.oracle.tlp.ClickHouseTLPHavingOracle; +import sqlancer.common.oracle.NoRECOracle; +import sqlancer.common.oracle.TLPWhereOracle; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; + +public enum ClickHouseOracleFactory implements OracleFactory { + TLPWhere { + @Override + public TestOracle create(ClickHouseGlobalState globalState) throws SQLException { + ClickHouseExpressionGenerator gen = new ClickHouseExpressionGenerator(globalState); + ExpectedErrors expectedErrors = ExpectedErrors.newErrors() + .with(ClickHouseErrors.getExpectedExpressionErrors()).build(); + + return new TLPWhereOracle<>(globalState, gen, expectedErrors); + } + }, + TLPDistinct { + @Override + public TestOracle create(ClickHouseGlobalState globalState) throws SQLException { + return new ClickHouseTLPDistinctOracle(globalState); + } + }, + TLPGroupBy { + @Override + public TestOracle create(ClickHouseGlobalState globalState) throws SQLException { + return new ClickHouseTLPGroupByOracle(globalState); + } + }, + TLPAggregate { + @Override + public TestOracle create(ClickHouseGlobalState globalState) throws SQLException { + return new ClickHouseTLPAggregateOracle(globalState); + } + }, + TLPHaving { + @Override + public TestOracle create(ClickHouseGlobalState globalState) throws SQLException { + return new ClickHouseTLPHavingOracle(globalState); + } + }, + NoREC { + @Override + public TestOracle create(ClickHouseGlobalState globalState) throws SQLException { + ClickHouseExpressionGenerator gen = new ClickHouseExpressionGenerator(globalState); + ExpectedErrors errors = ExpectedErrors.newErrors().with(ClickHouseErrors.getExpectedExpressionErrors()) + .with("canceling statement due to statement timeout").build(); + + return new NoRECOracle<>(globalState, gen, errors); + } + } +} diff --git a/src/sqlancer/clickhouse/ClickHouseProvider.java b/src/sqlancer/clickhouse/ClickHouseProvider.java index 7ff2de11b..c611d77d7 100644 --- a/src/sqlancer/clickhouse/ClickHouseProvider.java +++ b/src/sqlancer/clickhouse/ClickHouseProvider.java @@ -87,7 +87,7 @@ protected ClickHouseSchema readSchema() throws SQLException { @Override public void generateDatabase(ClickHouseGlobalState globalState) throws Exception { - for (int i = 0; i < Randomly.fromOptions(1); i++) { + for (int i = 0; i < Randomly.fromOptions(1, 2, 3, 4, 5); i++) { boolean success; do { String tableName = ClickHouseCommon.createTableName(i); @@ -96,6 +96,7 @@ public void generateDatabase(ClickHouseGlobalState globalState) throws Exception } while (!success); } + // TODO: add more Actions to populate table StatementExecutor se = new StatementExecutor<>(globalState, Action.values(), ClickHouseProvider::mapActions, (q) -> { if (globalState.getSchema().getDatabaseTables().isEmpty()) { @@ -126,6 +127,8 @@ public SQLConnection createDatabase(ClickHouseGlobalState globalState) throws SQ globalState.getState().logStatement(dropDatabaseCommand); String createDatabaseCommand = "CREATE DATABASE IF NOT EXISTS " + databaseName; globalState.getState().logStatement(createDatabaseCommand); + String useDatabaseCommand = "USE " + databaseName; // Noop. To reproduce easier. + globalState.getState().logStatement(useDatabaseCommand); try (Statement s = con.createStatement()) { s.execute(dropDatabaseCommand); Thread.sleep(1000); @@ -139,7 +142,9 @@ public SQLConnection createDatabase(ClickHouseGlobalState globalState) throws SQ e.printStackTrace(); } con.close(); - con = DriverManager.getConnection(String.format("jdbc:clickhouse://%s:%d/%s", host, port, databaseName), + con = DriverManager.getConnection( + String.format("jdbc:clickhouse://%s:%d/%s?socket_timeout=300000%s", host, port, databaseName, + clickHouseOptions.enableAnalyzer ? "&allow_experimental_analyzer=1" : ""), globalState.getOptions().getUserName(), globalState.getOptions().getPassword()); return new SQLConnection(con); } diff --git a/src/sqlancer/clickhouse/ClickHouseSchema.java b/src/sqlancer/clickhouse/ClickHouseSchema.java index 21b39e021..8f8f906ec 100644 --- a/src/sqlancer/clickhouse/ClickHouseSchema.java +++ b/src/sqlancer/clickhouse/ClickHouseSchema.java @@ -8,12 +8,15 @@ import java.util.List; import java.util.Map; -import ru.yandex.clickhouse.domain.ClickHouseDataType; +import com.clickhouse.client.ClickHouseDataType; + import sqlancer.Randomly; import sqlancer.SQLConnection; import sqlancer.clickhouse.ClickHouseProvider.ClickHouseGlobalState; import sqlancer.clickhouse.ClickHouseSchema.ClickHouseTable; +import sqlancer.clickhouse.ast.ClickHouseColumnReference; import sqlancer.clickhouse.ast.ClickHouseConstant; +import sqlancer.clickhouse.ast.constant.ClickHouseCreateConstant; import sqlancer.common.schema.AbstractRelationalTable; import sqlancer.common.schema.AbstractRowValue; import sqlancer.common.schema.AbstractSchema; @@ -34,7 +37,7 @@ public ClickHouseLancerDataType(ClickHouseDataType type) { } public ClickHouseLancerDataType(String textRepr) { - this.clickHouseType = ClickHouseDataType.fromTypeString(textRepr); + this.clickHouseType = ClickHouseDataType.of(textRepr); this.textRepr = textRepr; } @@ -60,14 +63,15 @@ public static class ClickHouseColumn extends AbstractTableColumn databaseColumns = getTableColumns(con, tableName); List indexes = Collections.emptyList(); - boolean isView = tableName.startsWith("v"); + boolean isView = matchesViewName(tableName); ClickHouseTable t = new ClickHouseTable(tableName, databaseColumns, indexes, isView); for (ClickHouseColumn c : databaseColumns) { c.setTable(t); @@ -217,7 +250,7 @@ private static List getTableColumns(SQLConnection con, String boolean isAlias = "ALIAS".compareTo(defaultType) == 0; boolean isMaterialized = "MATERIALIZED".compareTo(defaultType) == 0; ClickHouseColumn c = new ClickHouseColumn(columnName, getColumnType(dataType), isAlias, - isMaterialized); + isMaterialized, null); columns.add(c); } } diff --git a/src/sqlancer/clickhouse/ClickHouseToStringVisitor.java b/src/sqlancer/clickhouse/ClickHouseToStringVisitor.java index beaa37b49..29bcddcb7 100644 --- a/src/sqlancer/clickhouse/ClickHouseToStringVisitor.java +++ b/src/sqlancer/clickhouse/ClickHouseToStringVisitor.java @@ -1,7 +1,10 @@ package sqlancer.clickhouse; +import java.util.List; + import sqlancer.clickhouse.ast.ClickHouseAggregate; -import sqlancer.clickhouse.ast.ClickHouseBinaryComparisonOperation; +import sqlancer.clickhouse.ast.ClickHouseAliasOperation; +import sqlancer.clickhouse.ast.ClickHouseBinaryFunctionOperation; import sqlancer.clickhouse.ast.ClickHouseBinaryLogicalOperation; import sqlancer.clickhouse.ast.ClickHouseCastOperation; import sqlancer.clickhouse.ast.ClickHouseColumnReference; @@ -20,17 +23,6 @@ public void visitSpecific(ClickHouseExpression expr) { ClickHouseVisitor.super.visit(expr); } - @Override - public void visit(ClickHouseBinaryComparisonOperation op) { - sb.append("("); - visit(op.getLeft()); - sb.append(") "); - sb.append(op.getOperator().getTextRepresentation()); - sb.append(" ("); - visit(op.getRight()); - sb.append(")"); - } - @Override public void visit(ClickHouseBinaryLogicalOperation op) { sb.append("("); @@ -82,8 +74,17 @@ public void visit(ClickHouseSelect select, boolean inner) { } visit(select.getFetchColumns()); - sb.append(" FROM "); - visit(select.getFromList()); + List fromList = select.getFromList(); + if (fromList != null) { + sb.append(" FROM "); + visit(fromList); + } + List joins = select.getJoinClauses(); + if (!joins.isEmpty()) { + for (ClickHouseExpression.ClickHouseJoin join : joins) { + visit(join); + } + } if (select.getWhereClause() != null) { sb.append(" WHERE "); visit(select.getWhereClause()); @@ -96,9 +97,9 @@ public void visit(ClickHouseSelect select, boolean inner) { sb.append(" HAVING "); visit(select.getHavingClause()); } - if (!select.getOrderByClause().isEmpty()) { + if (!select.getOrderByClauses().isEmpty()) { sb.append(" ORDER BY "); - visit(select.getOrderByClause()); + visit(select.getOrderByClauses()); } if (inner) { sb.append(")"); @@ -107,7 +108,12 @@ public void visit(ClickHouseSelect select, boolean inner) { @Override public void visit(ClickHouseTableReference tableReference) { - sb.append(tableReference.getTable().getName()); + sb.append(tableReference.getTable().getName()); // Original name, not alias. + String alias = tableReference.getAlias(); + if (alias != null) { + sb.append(" AS " + alias); + } + } @Override @@ -129,16 +135,70 @@ public void visit(ClickHouseCastOperation cast) { @Override public void visit(ClickHouseExpression.ClickHouseJoin join) { - + ClickHouseExpression.ClickHouseJoin.JoinType type = join.getType(); + if (type == ClickHouseExpression.ClickHouseJoin.JoinType.CROSS) { + sb.append(" JOIN "); + visit(join.getRightTable()); + } else if (type == ClickHouseExpression.ClickHouseJoin.JoinType.INNER) { + sb.append(" INNER JOIN "); + visit(join.getRightTable()); + } else if (type == ClickHouseExpression.ClickHouseJoin.JoinType.LEFT_OUTER) { + sb.append(" LEFT OUTER JOIN "); + visit(join.getRightTable()); + } else if (type == ClickHouseExpression.ClickHouseJoin.JoinType.RIGHT_OUTER) { + sb.append(" RIGHT OUTER JOIN "); + visit(join.getRightTable()); + } else if (type == ClickHouseExpression.ClickHouseJoin.JoinType.FULL_OUTER) { + sb.append(" FULL OUTER JOIN "); + visit(join.getRightTable()); + } else if (type == ClickHouseExpression.ClickHouseJoin.JoinType.LEFT_ANTI) { + sb.append(" LEFT ANTI JOIN "); + visit(join.getRightTable()); + } else if (type == ClickHouseExpression.ClickHouseJoin.JoinType.RIGHT_ANTI) { + sb.append(" RIGHT ANTI JOIN "); + visit(join.getRightTable()); + } else { + throw new UnsupportedOperationException(); + } + ClickHouseExpression onClause = join.getOnClause(); + if (onClause != null) { + sb.append(" ON "); + visit(onClause); + } } @Override public void visit(ClickHouseColumnReference c) { - if (c.getColumn().getTable() == null) { + if (c.getTableAlias() != null) { + sb.append(c.getTableAlias()); + sb.append("."); + sb.append(c.getColumn().getName()); + } else if (c.getColumn().getTable() == null) { sb.append(c.getColumn().getName()); } else { sb.append(c.getColumn().getFullQualifiedName()); } + if (c.getAlias() != null) { + sb.append(" AS " + c.getAlias()); + } + } + + @Override + public void visit(ClickHouseBinaryFunctionOperation func) { + sb.append(func.getOperatorRepresentation()); + sb.append("("); + visit(func.getLeft()); + sb.append(","); + visit(func.getRight()); + sb.append(")"); + } + + @Override + public void visit(ClickHouseAliasOperation alias) { + visit(alias.getExpression()); + sb.append(" AS `"); + sb.append(alias.getAlias()); + sb.append("`"); } public static String asString(ClickHouseExpression expr) { diff --git a/src/sqlancer/clickhouse/ClickHouseVisitor.java b/src/sqlancer/clickhouse/ClickHouseVisitor.java index e1473f669..2966f93be 100644 --- a/src/sqlancer/clickhouse/ClickHouseVisitor.java +++ b/src/sqlancer/clickhouse/ClickHouseVisitor.java @@ -1,7 +1,9 @@ package sqlancer.clickhouse; import sqlancer.clickhouse.ast.ClickHouseAggregate; +import sqlancer.clickhouse.ast.ClickHouseAliasOperation; import sqlancer.clickhouse.ast.ClickHouseBinaryComparisonOperation; +import sqlancer.clickhouse.ast.ClickHouseBinaryFunctionOperation; import sqlancer.clickhouse.ast.ClickHouseBinaryLogicalOperation; import sqlancer.clickhouse.ast.ClickHouseCastOperation; import sqlancer.clickhouse.ast.ClickHouseColumnReference; @@ -51,12 +53,18 @@ default void visit(ClickHouseExpression.ClickHousePostfixText op) { void visit(ClickHouseCastOperation cast); + void visit(ClickHouseAliasOperation alias); + void visit(ClickHouseExpression.ClickHouseJoin join); void visit(ClickHouseAggregate aggregate); + void visit(ClickHouseBinaryFunctionOperation func); + default void visit(ClickHouseExpression expr) { - if (expr instanceof ClickHouseBinaryComparisonOperation) { + if (expr instanceof ClickHouseBinaryFunctionOperation) { + visit((ClickHouseBinaryFunctionOperation) expr); + } else if (expr instanceof ClickHouseBinaryComparisonOperation) { visit((ClickHouseBinaryComparisonOperation) expr); } else if (expr instanceof ClickHouseBinaryLogicalOperation) { visit((ClickHouseBinaryLogicalOperation) expr); @@ -78,6 +86,10 @@ default void visit(ClickHouseExpression expr) { visit((ClickHouseExpression.ClickHousePostfixText) expr); } else if (expr instanceof ClickHouseAggregate) { visit((ClickHouseAggregate) expr); + } else if (expr instanceof ClickHouseAliasOperation) { + visit((ClickHouseAliasOperation) expr); + } else if (expr instanceof ClickHouseExpression.ClickHouseJoinOnClause) { + visit((ClickHouseExpression.ClickHouseJoinOnClause) expr); } else { throw new AssertionError(expr); } diff --git a/src/sqlancer/clickhouse/ast/ClickHouseAggregate.java b/src/sqlancer/clickhouse/ast/ClickHouseAggregate.java index 172554b77..a0483bd57 100644 --- a/src/sqlancer/clickhouse/ast/ClickHouseAggregate.java +++ b/src/sqlancer/clickhouse/ast/ClickHouseAggregate.java @@ -4,20 +4,20 @@ import java.util.List; import java.util.stream.Collectors; -import ru.yandex.clickhouse.domain.ClickHouseDataType; +import com.clickhouse.client.ClickHouseDataType; + import sqlancer.Randomly; import sqlancer.clickhouse.ClickHouseSchema; public class ClickHouseAggregate extends ClickHouseExpression { private final ClickHouseAggregate.ClickHouseAggregateFunction func; - private final List expr; + private final ClickHouseExpression expr; public enum ClickHouseAggregateFunction { AVG(ClickHouseDataType.Int8, ClickHouseDataType.Int16, ClickHouseDataType.Int32, ClickHouseDataType.Int64, ClickHouseDataType.UInt8, ClickHouseDataType.UInt16, ClickHouseDataType.UInt32, ClickHouseDataType.UInt64, ClickHouseDataType.Float32, ClickHouseDataType.Float64), - BOOL_AND(ClickHouseDataType.UInt8), BOOL_OR(ClickHouseDataType.UInt8), COUNT(ClickHouseDataType.Int8, ClickHouseDataType.Int16, ClickHouseDataType.Int32, ClickHouseDataType.Int64, ClickHouseDataType.UInt8, ClickHouseDataType.UInt16, ClickHouseDataType.UInt32, ClickHouseDataType.UInt64, ClickHouseDataType.Float32, ClickHouseDataType.Float64, @@ -41,8 +41,8 @@ public static ClickHouseAggregateFunction getRandom(ClickHouseDataType type) { return Randomly.fromOptions(values()); } - public List getTypes(ClickHouseDataType returnType) { - return Arrays.asList(returnType); + public ClickHouseDataType getType(ClickHouseDataType returnType) { + return returnType; } public boolean supportsReturnType(ClickHouseDataType returnType) { @@ -65,7 +65,7 @@ public ClickHouseSchema.ClickHouseLancerDataType getRandomReturnType() { } - public ClickHouseAggregate(List expr, ClickHouseAggregateFunction func) { + public ClickHouseAggregate(ClickHouseExpression expr, ClickHouseAggregateFunction func) { this.expr = expr; this.func = func; } @@ -74,7 +74,7 @@ public ClickHouseAggregate.ClickHouseAggregateFunction getFunc() { return func; } - public List getExpr() { + public ClickHouseExpression getExpr() { return expr; } diff --git a/src/sqlancer/clickhouse/ast/ClickHouseAliasOperation.java b/src/sqlancer/clickhouse/ast/ClickHouseAliasOperation.java new file mode 100644 index 000000000..b7976ff46 --- /dev/null +++ b/src/sqlancer/clickhouse/ast/ClickHouseAliasOperation.java @@ -0,0 +1,28 @@ +package sqlancer.clickhouse.ast; + +public class ClickHouseAliasOperation extends ClickHouseExpression { + + private final ClickHouseExpression expression; + private final String alias; + + public ClickHouseAliasOperation(ClickHouseExpression expression, String alias) { + if (expression == null) { + throw new AssertionError(); + } + this.expression = expression; + this.alias = alias; + } + + @Override + public ClickHouseConstant getExpectedValue() { + return expression.getExpectedValue(); + } + + public ClickHouseExpression getExpression() { + return expression; + } + + public String getAlias() { + return alias; + } +} diff --git a/src/sqlancer/clickhouse/ast/ClickHouseBinaryArithmeticOperation.java b/src/sqlancer/clickhouse/ast/ClickHouseBinaryArithmeticOperation.java new file mode 100644 index 000000000..fc0f8504a --- /dev/null +++ b/src/sqlancer/clickhouse/ast/ClickHouseBinaryArithmeticOperation.java @@ -0,0 +1,65 @@ +package sqlancer.clickhouse.ast; + +import sqlancer.Randomly; +import sqlancer.common.visitor.BinaryOperation; + +public class ClickHouseBinaryArithmeticOperation extends ClickHouseExpression + implements BinaryOperation { + + public enum ClickHouseBinaryArithmeticOperator { + ADD("+"), // + MINUS("-"), // + MULT("*"), // + DIV("/"), // + MODULO("%"); // + + String textRepresentation; + + ClickHouseBinaryArithmeticOperator(String textRepresentation) { + this.textRepresentation = textRepresentation; + } + + public static ClickHouseBinaryArithmeticOperator getRandom() { + return Randomly.fromOptions(values()); + } + + public String getTextRepresentation() { + return textRepresentation; + } + } + + private final ClickHouseBinaryArithmeticOperation.ClickHouseBinaryArithmeticOperator operation; + private final ClickHouseExpression left; + private final ClickHouseExpression right; + + public ClickHouseBinaryArithmeticOperation(ClickHouseExpression left, ClickHouseExpression right, + ClickHouseBinaryArithmeticOperation.ClickHouseBinaryArithmeticOperator operation) { + this.left = left; + this.right = right; + this.operation = operation; + } + + public ClickHouseBinaryArithmeticOperation.ClickHouseBinaryArithmeticOperator getOperator() { + return operation; + } + + @Override + public ClickHouseExpression getLeft() { + return left; + } + + @Override + public ClickHouseExpression getRight() { + return right; + } + + @Override + public String getOperatorRepresentation() { + return operation.getTextRepresentation(); + } + + public static ClickHouseBinaryArithmeticOperation create(ClickHouseExpression left, ClickHouseExpression right, + ClickHouseBinaryArithmeticOperation.ClickHouseBinaryArithmeticOperator op) { + return new ClickHouseBinaryArithmeticOperation(left, right, op); + } +} diff --git a/src/sqlancer/clickhouse/ast/ClickHouseBinaryComparisonOperation.java b/src/sqlancer/clickhouse/ast/ClickHouseBinaryComparisonOperation.java index 6696a3ee3..75dc259e8 100644 --- a/src/sqlancer/clickhouse/ast/ClickHouseBinaryComparisonOperation.java +++ b/src/sqlancer/clickhouse/ast/ClickHouseBinaryComparisonOperation.java @@ -1,8 +1,10 @@ package sqlancer.clickhouse.ast; -import ru.yandex.clickhouse.domain.ClickHouseDataType; +import com.clickhouse.client.ClickHouseDataType; + import sqlancer.LikeImplementationHelper; import sqlancer.Randomly; +import sqlancer.clickhouse.ast.constant.ClickHouseCreateConstant; import sqlancer.common.visitor.BinaryOperation; public class ClickHouseBinaryComparisonOperation extends ClickHouseExpression @@ -75,7 +77,7 @@ ClickHouseConstant apply(ClickHouseConstant left, ClickHouseConstant right) { } else if (lessThan.asInt() >= 1) { return lessThan; } else { - return ClickHouseConstant.createFalse(); + return ClickHouseCreateConstant.createFalse(); } } } @@ -95,7 +97,7 @@ ClickHouseConstant apply(ClickHouseConstant left, ClickHouseConstant right) { && equals.getDataType() == ClickHouseDataType.UInt32 && equals.getDataType() == ClickHouseDataType.Int64 && equals.getDataType() == ClickHouseDataType.UInt64 && equals.asInt() == 1) { - return ClickHouseConstant.createFalse(); + return ClickHouseCreateConstant.createFalse(); } else { ClickHouseConstant applyLess = left.applyLess(right); if (applyLess == null) { @@ -121,7 +123,7 @@ ClickHouseConstant apply(ClickHouseConstant left, ClickHouseConstant right) { && lessThan.getDataType() == ClickHouseDataType.UInt32 && lessThan.getDataType() == ClickHouseDataType.Int64 && lessThan.getDataType() == ClickHouseDataType.UInt64 && lessThan.asInt() >= 1) { - return ClickHouseConstant.createTrue(); + return ClickHouseCreateConstant.createTrue(); } else { ClickHouseConstant applyLess = left.applyLess(right); if (applyLess == null) { @@ -146,14 +148,14 @@ ClickHouseConstant apply(ClickHouseConstant left, ClickHouseConstant right) { return null; } if (left.isNull() || right.isNull()) { - return ClickHouseConstant.createNullConstant(); + return ClickHouseCreateConstant.createNullConstant(); } else { ClickHouseConstant applyEquals = left.applyEquals(right); if (applyEquals == null) { return null; } boolean equals = applyEquals.asInt() == 1; - return ClickHouseConstant.createBoolean(!equals); + return ClickHouseCreateConstant.createBoolean(!equals); } } @@ -165,7 +167,7 @@ ClickHouseConstant apply(ClickHouseConstant left, ClickHouseConstant right) { return null; } if (left.isNull() || right.isNull()) { - return ClickHouseConstant.createNullConstant(); + return ClickHouseCreateConstant.createNullConstant(); } ClickHouseConstant leftStr = ClickHouseCast.castToText(left); ClickHouseConstant rightStr = ClickHouseCast.castToText(right); @@ -173,7 +175,7 @@ ClickHouseConstant apply(ClickHouseConstant left, ClickHouseConstant right) { return null; } boolean val = LikeImplementationHelper.match(leftStr.asString(), rightStr.asString(), 0, 0, false); - return ClickHouseConstant.createBoolean(val); + return ClickHouseCreateConstant.createBoolean(val); } }; diff --git a/src/sqlancer/clickhouse/ast/ClickHouseBinaryFunctionOperation.java b/src/sqlancer/clickhouse/ast/ClickHouseBinaryFunctionOperation.java new file mode 100644 index 000000000..65bf4c79e --- /dev/null +++ b/src/sqlancer/clickhouse/ast/ClickHouseBinaryFunctionOperation.java @@ -0,0 +1,57 @@ +package sqlancer.clickhouse.ast; + +import sqlancer.Randomly; + +public class ClickHouseBinaryFunctionOperation extends ClickHouseExpression { + + public enum ClickHouseBinaryFunctionOperator { + INT_DIV("intDiv"), GCD("gcd"), LCM("lcm"), MAX2("max2"), MIN2("min2"), POW("pow"); + + String textRepresentation; + + ClickHouseBinaryFunctionOperator(String textRepresentation) { + this.textRepresentation = textRepresentation; + } + + public static ClickHouseBinaryFunctionOperator getRandom() { + return Randomly.fromOptions(values()); + } + + public String getTextRepresentation() { + return textRepresentation; + } + } + + private final ClickHouseBinaryFunctionOperator operation; + private final ClickHouseExpression left; + private final ClickHouseExpression right; + + public ClickHouseBinaryFunctionOperation(ClickHouseExpression left, ClickHouseExpression right, + ClickHouseBinaryFunctionOperator operation) { + this.left = left; + this.right = right; + this.operation = operation; + } + + public ClickHouseBinaryFunctionOperator getOperator() { + return operation; + } + + public ClickHouseExpression getLeft() { + return left; + } + + public ClickHouseExpression getRight() { + return right; + } + + public String getOperatorRepresentation() { + return operation.getTextRepresentation(); + } + + public static ClickHouseBinaryFunctionOperation create(ClickHouseExpression left, ClickHouseExpression right, + ClickHouseBinaryFunctionOperator op) { + return new ClickHouseBinaryFunctionOperation(left, right, op); + } + +} diff --git a/src/sqlancer/clickhouse/ast/ClickHouseBinaryLogicalOperation.java b/src/sqlancer/clickhouse/ast/ClickHouseBinaryLogicalOperation.java index bbd37e865..019fad752 100644 --- a/src/sqlancer/clickhouse/ast/ClickHouseBinaryLogicalOperation.java +++ b/src/sqlancer/clickhouse/ast/ClickHouseBinaryLogicalOperation.java @@ -1,6 +1,7 @@ package sqlancer.clickhouse.ast; import sqlancer.Randomly; +import sqlancer.clickhouse.ast.constant.ClickHouseCreateConstant; public class ClickHouseBinaryLogicalOperation extends ClickHouseExpression { @@ -14,22 +15,22 @@ public enum ClickHouseBinaryLogicalOperator { @Override public ClickHouseConstant apply(ClickHouseConstant left, ClickHouseConstant right) { if (left.isNull() && right.isNull()) { - return ClickHouseConstant.createNullConstant(); + return ClickHouseCreateConstant.createNullConstant(); } else if (left.isNull()) { if (right.asBooleanNotNull()) { - return ClickHouseConstant.createNullConstant(); + return ClickHouseCreateConstant.createNullConstant(); } else { - return ClickHouseConstant.createFalse(); + return ClickHouseCreateConstant.createFalse(); } } else if (right.isNull()) { if (left.asBooleanNotNull()) { - return ClickHouseConstant.createNullConstant(); + return ClickHouseCreateConstant.createNullConstant(); } else { - return ClickHouseConstant.createFalse(); + return ClickHouseCreateConstant.createFalse(); } } else { - return left.asBooleanNotNull() && right.asBooleanNotNull() ? ClickHouseConstant.createTrue() - : ClickHouseConstant.createFalse(); + return left.asBooleanNotNull() && right.asBooleanNotNull() ? ClickHouseCreateConstant.createTrue() + : ClickHouseCreateConstant.createFalse(); } } }, @@ -37,13 +38,13 @@ public ClickHouseConstant apply(ClickHouseConstant left, ClickHouseConstant righ @Override public ClickHouseConstant apply(ClickHouseConstant left, ClickHouseConstant right) { if (!left.isNull() && left.asBooleanNotNull()) { - return ClickHouseConstant.createTrue(); + return ClickHouseCreateConstant.createTrue(); } else if (!right.isNull() && right.asBooleanNotNull()) { - return ClickHouseConstant.createTrue(); + return ClickHouseCreateConstant.createTrue(); } else if (left.isNull() || right.isNull()) { - return ClickHouseConstant.createNullConstant(); + return ClickHouseCreateConstant.createNullConstant(); } else { - return ClickHouseConstant.createFalse(); + return ClickHouseCreateConstant.createFalse(); } } }; diff --git a/src/sqlancer/clickhouse/ast/ClickHouseCast.java b/src/sqlancer/clickhouse/ast/ClickHouseCast.java index dc35f1fc7..09c946717 100644 --- a/src/sqlancer/clickhouse/ast/ClickHouseCast.java +++ b/src/sqlancer/clickhouse/ast/ClickHouseCast.java @@ -6,7 +6,9 @@ import java.util.Optional; import java.util.regex.Pattern; -import ru.yandex.clickhouse.domain.ClickHouseDataType; +import com.clickhouse.client.ClickHouseDataType; + +import sqlancer.clickhouse.ast.constant.ClickHouseCreateConstant; public final class ClickHouseCast extends ClickHouseExpression { @@ -50,18 +52,18 @@ public static Optional isTrue(ClickHouseConstant value) { public static ClickHouseConstant castToInt(ClickHouseConstant cons) { switch (cons.getDataType()) { case Nothing: - return ClickHouseConstant.createNullConstant(); + return ClickHouseCreateConstant.createNullConstant(); case Int32: return cons; case Float64: - return ClickHouseConstant.createInt32Constant((long) cons.asDouble()); + return ClickHouseCreateConstant.createInt32Constant((long) cons.asDouble()); case String: String asString = cons.asString(); while (startsWithWhitespace(asString)) { asString = asString.substring(1); } if (!asString.isEmpty() && unprintAbleCharThatLetsBecomeNumberZero(asString)) { - return ClickHouseConstant.createInt32Constant(0); + return ClickHouseCreateConstant.createInt32Constant(0); } for (int i = asString.length(); i >= 0; i--) { try { @@ -79,13 +81,13 @@ public static ClickHouseConstant castToInt(ClickHouseConstant cons) { result = Long.MAX_VALUE; } } - return ClickHouseConstant.createInt32Constant(result); + return ClickHouseCreateConstant.createInt32Constant(result); } } catch (Exception e) { } } - return ClickHouseConstant.createInt32Constant(0); + return ClickHouseCreateConstant.createInt32Constant(0); default: throw new AssertionError(); } @@ -95,7 +97,7 @@ public static ClickHouseConstant castToInt(ClickHouseConstant cons) { public static ClickHouseConstant castToReal(ClickHouseConstant cons) { ClickHouseConstant numericValue = castToNumeric(cons); if (numericValue.getDataType() == ClickHouseDataType.Int32) { - return ClickHouseConstant.createFloat64Constant(numericValue.asInt()); + return ClickHouseCreateConstant.createFloat64Constant(numericValue.asInt()); } else { return numericValue; } @@ -120,7 +122,7 @@ private static ClickHouseConstant convertInternal(ClickHouseConstant value, bool boolean noNumIsRealZero, boolean convertIntToReal) throws AssertionError { switch (value.getDataType()) { case Nothing: - return ClickHouseConstant.createNullConstant(); + return ClickHouseCreateConstant.createNullConstant(); case Int32: case Float64: return value; @@ -130,11 +132,11 @@ private static ClickHouseConstant convertInternal(ClickHouseConstant value, bool asString = asString.substring(1); } if (!asString.isEmpty() && unprintAbleCharThatLetsBecomeNumberZero(asString)) { - return ClickHouseConstant.createInt32Constant(0); + return ClickHouseCreateConstant.createInt32Constant(0); } if (asString.toLowerCase().startsWith("-infinity") || asString.toLowerCase().startsWith("infinity") || asString.startsWith("NaN")) { - return ClickHouseConstant.createInt32Constant(0); + return ClickHouseCreateConstant.createInt32Constant(0); } for (int i = asString.length(); i >= 0; i--) { try { @@ -151,17 +153,17 @@ private static ClickHouseConstant convertInternal(ClickHouseConstant value, bool boolean isInteger = !isFloatingPointNumber && first.compareTo(second) == 0; if (doubleShouldBeConvertedToInt || isInteger && !convertIntToReal) { // see https://www.sqlite.org/src/tktview/afdc5a29dc - return ClickHouseConstant.createInt32Constant(first.longValue()); + return ClickHouseCreateConstant.createInt32Constant(first.longValue()); } else { - return ClickHouseConstant.createFloat64Constant(d); + return ClickHouseCreateConstant.createFloat64Constant(d); } } catch (Exception e) { } } if (noNumIsRealZero) { - return ClickHouseConstant.createFloat64Constant(0.0); + return ClickHouseCreateConstant.createFloat64Constant(0.0); } else { - return ClickHouseConstant.createInt32Constant(0); + return ClickHouseCreateConstant.createInt32Constant(0); } default: throw new AssertionError(value); @@ -222,14 +224,14 @@ public static ClickHouseConstant castToText(ClickHouseConstant cons) { } if (cons.getDataType() == ClickHouseDataType.Float64) { if (cons.asDouble() == Double.POSITIVE_INFINITY) { - return ClickHouseConstant.createStringConstant("Inf"); + return ClickHouseCreateConstant.createStringConstant("Inf"); } else if (cons.asDouble() == Double.NEGATIVE_INFINITY) { - return ClickHouseConstant.createStringConstant("-Inf"); + return ClickHouseCreateConstant.createStringConstant("-Inf"); } return castRealToText(cons); } if (cons.getDataType() == ClickHouseDataType.Int32) { - return ClickHouseConstant.createStringConstant(String.valueOf(cons.asInt())); + return ClickHouseCreateConstant.createStringConstant(String.valueOf(cons.asInt())); } return null; } @@ -237,7 +239,7 @@ public static ClickHouseConstant castToText(ClickHouseConstant cons) { private static synchronized ClickHouseConstant castRealToText(ClickHouseConstant cons) throws AssertionError { try (Statement s = castDatabase.createStatement()) { String castResult = s.executeQuery("SELECT CAST(" + cons.asDouble() + " AS TEXT)").getString(1); - return ClickHouseConstant.createStringConstant(castResult); + return ClickHouseCreateConstant.createStringConstant(castResult); } catch (Exception e) { throw new AssertionError(e); } @@ -246,9 +248,9 @@ private static synchronized ClickHouseConstant castRealToText(ClickHouseConstant public static ClickHouseConstant asBoolean(ClickHouseConstant val) { Optional boolVal = isTrue(val); if (boolVal.isPresent()) { - return ClickHouseConstant.createBoolean(boolVal.get()); + return ClickHouseCreateConstant.createBoolean(boolVal.get()); } else { - return ClickHouseConstant.createNullConstant(); + return ClickHouseCreateConstant.createNullConstant(); } } diff --git a/src/sqlancer/clickhouse/ast/ClickHouseCastOperation.java b/src/sqlancer/clickhouse/ast/ClickHouseCastOperation.java index 490a77d37..a905b4e49 100644 --- a/src/sqlancer/clickhouse/ast/ClickHouseCastOperation.java +++ b/src/sqlancer/clickhouse/ast/ClickHouseCastOperation.java @@ -1,6 +1,7 @@ package sqlancer.clickhouse.ast; -import ru.yandex.clickhouse.domain.ClickHouseDataType; +import com.clickhouse.client.ClickHouseDataType; + import sqlancer.clickhouse.ClickHouseSchema.ClickHouseLancerDataType; public class ClickHouseCastOperation extends ClickHouseExpression { diff --git a/src/sqlancer/clickhouse/ast/ClickHouseColumnReference.java b/src/sqlancer/clickhouse/ast/ClickHouseColumnReference.java index 818da0e7e..6d75cf5fa 100644 --- a/src/sqlancer/clickhouse/ast/ClickHouseColumnReference.java +++ b/src/sqlancer/clickhouse/ast/ClickHouseColumnReference.java @@ -5,28 +5,30 @@ public class ClickHouseColumnReference extends ClickHouseExpression { private final ClickHouseColumn column; - private final ClickHouseConstant value; + private final String columnAlias; + private final String tableAlias; - public ClickHouseColumnReference(ClickHouseColumn column, ClickHouseConstant value) { + public ClickHouseColumnReference(ClickHouseColumn column, String columnAlias, String tableAlias) { this.column = column; - this.value = value; + this.columnAlias = columnAlias; + this.tableAlias = tableAlias; } - public static ClickHouseColumnReference create(ClickHouseColumn column, ClickHouseConstant value) { - return new ClickHouseColumnReference(column, value); + public ClickHouseColumnReference(ClickHouseAliasOperation alias) { + this.column = new ClickHouseColumn(alias.getAlias(), null, true, false, null); + this.columnAlias = null; + this.tableAlias = null; } public ClickHouseColumn getColumn() { return column; } - public ClickHouseConstant getValue() { - return value; + public String getAlias() { + return columnAlias; } - @Override - public ClickHouseConstant getExpectedValue() { - return value; + public String getTableAlias() { + return tableAlias; } - } diff --git a/src/sqlancer/clickhouse/ast/ClickHouseConstant.java b/src/sqlancer/clickhouse/ast/ClickHouseConstant.java index 048c14296..b38ad7e89 100644 --- a/src/sqlancer/clickhouse/ast/ClickHouseConstant.java +++ b/src/sqlancer/clickhouse/ast/ClickHouseConstant.java @@ -1,1899 +1,13 @@ package sqlancer.clickhouse.ast; -import java.math.BigInteger; +import com.clickhouse.client.ClickHouseDataType; -import ru.yandex.clickhouse.domain.ClickHouseDataType; -import sqlancer.IgnoreMeException; +import sqlancer.clickhouse.ast.constant.ClickHouseCreateConstant; public abstract class ClickHouseConstant extends ClickHouseExpression { - public static class ClickHouseNullConstant extends ClickHouseConstant { - - @Override - public String toString() { - return "NULL"; - } - - @Override - public boolean isNull() { - return true; - } - - @Override - public boolean asBooleanNotNull() { - throw new AssertionError(); - } - - @Override - public ClickHouseDataType getDataType() { - return ClickHouseDataType.Nothing; - } - - @Override - public boolean compareInternal(Object value) { - return false; - } - - @Override - public ClickHouseConstant applyEquals(ClickHouseConstant right) { - return ClickHouseConstant.createNullConstant(); - } - - @Override - public ClickHouseConstant applyLess(ClickHouseConstant right) { - return ClickHouseConstant.createNullConstant(); - } - - @Override - public Object getValue() { - return null; - } - - @Override - public ClickHouseConstant cast(ClickHouseDataType type) { - return null; - } - } - - public static class ClickHouseUInt8Constant extends ClickHouseConstant { - - private final int value; - - public ClickHouseUInt8Constant(int value) { - this.value = value; - } - - @Override - public String toString() { - return String.valueOf(value); - } - - @Override - public boolean isNull() { - return false; - } - - @Override - public boolean asBooleanNotNull() { - return value != 0; - } - - @Override - public ClickHouseDataType getDataType() { - return ClickHouseDataType.UInt8; - } - - @Override - public boolean compareInternal(Object val) { - return value == (int) val; - } - - @Override - public ClickHouseConstant applyLess(ClickHouseConstant right) { - if (this.getDataType() == right.getDataType()) { - return this.asInt() < right.asInt() ? ClickHouseConstant.createTrue() - : ClickHouseConstant.createFalse(); - } - throw new IgnoreMeException(); - } - - @Override - public long asInt() { - return value; - } - - @Override - public Object getValue() { - return value; - } - - @Override - public ClickHouseConstant cast(ClickHouseDataType type) { - switch (type) { - case String: - return ClickHouseConstant.createStringConstant(this.toString()); - case UInt8: - return ClickHouseConstant.createUInt8Constant(value); - case Int8: - return ClickHouseConstant.createInt8Constant(value); - case UInt16: - return ClickHouseConstant.createUInt16Constant(value); - case Int16: - return ClickHouseConstant.createInt16Constant(value); - case UInt32: - return ClickHouseConstant.createUInt32Constant(value); - case Int32: - return ClickHouseConstant.createInt32Constant(value); - case UInt64: - return ClickHouseConstant.createUInt64Constant(BigInteger.valueOf(value)); - case Int64: - return ClickHouseConstant.createInt64Constant(BigInteger.valueOf(value)); - case UInt128: - return ClickHouseConstant.createUInt128Constant(BigInteger.valueOf(value)); - case Int128: - return ClickHouseConstant.createInt128Constant(BigInteger.valueOf(value)); - case UInt256: - return ClickHouseConstant.createUInt256Constant(BigInteger.valueOf(value)); - case Int256: - return ClickHouseConstant.createInt256Constant(BigInteger.valueOf(value)); - case Float32: - return ClickHouseConstant.createFloat32Constant((float) value); - case Float64: - return ClickHouseConstant.createFloat64Constant((double) value); - case Nothing: - return ClickHouseConstant.createNullConstant(); - case IntervalYear: - case IntervalQuarter: - case IntervalMonth: - case IntervalWeek: - case IntervalDay: - case IntervalHour: - case IntervalMinute: - case IntervalSecond: - case Date: - case DateTime: - case Enum8: - case Enum16: - case Decimal32: - case Decimal64: - case Decimal128: - case Decimal: - case UUID: - case FixedString: - case Nested: - case Tuple: - case Array: - case AggregateFunction: - case Unknown: - default: - throw new AssertionError(type); - } - } - } - - public static class ClickHouseInt8Constant extends ClickHouseConstant { - - private final int value; - - public ClickHouseInt8Constant(int value) { - this.value = value; - } - - @Override - public String toString() { - return String.valueOf(value); - } - - @Override - public boolean isNull() { - return false; - } - - @Override - public boolean asBooleanNotNull() { - return value != 0; - } - - @Override - public ClickHouseDataType getDataType() { - return ClickHouseDataType.Int8; - } - - @Override - public boolean compareInternal(Object val) { - return value == (int) val; - } - - @Override - public ClickHouseConstant applyLess(ClickHouseConstant right) { - if (this.getDataType() == right.getDataType()) { - return this.asInt() < right.asInt() ? ClickHouseConstant.createTrue() - : ClickHouseConstant.createFalse(); - } - throw new IgnoreMeException(); - } - - @Override - public long asInt() { - return value; - } - - @Override - public Object getValue() { - return value; - } - - @Override - public ClickHouseConstant cast(ClickHouseDataType type) { - switch (type) { - case String: - return ClickHouseConstant.createStringConstant(this.toString()); - case UInt8: - return ClickHouseConstant.createUInt8Constant(value); - case Int8: - return ClickHouseConstant.createInt8Constant(value); - case UInt16: - return ClickHouseConstant.createUInt16Constant(value); - case Int16: - return ClickHouseConstant.createInt16Constant(value); - case UInt32: - return ClickHouseConstant.createUInt32Constant(value); - case Int32: - return ClickHouseConstant.createInt32Constant(value); - case UInt64: - return ClickHouseConstant.createUInt64Constant(BigInteger.valueOf(value)); - case Int64: - return ClickHouseConstant.createInt64Constant(BigInteger.valueOf(value)); - case Float32: - return ClickHouseConstant.createFloat32Constant((float) value); - case Float64: - return ClickHouseConstant.createFloat64Constant(value); - case Nothing: - return ClickHouseConstant.createNullConstant(); - case IntervalYear: - case IntervalQuarter: - case IntervalMonth: - case IntervalWeek: - case IntervalDay: - case IntervalHour: - case IntervalMinute: - case IntervalSecond: - case Date: - case DateTime: - case Enum8: - case Enum16: - case Decimal32: - case Decimal64: - case Decimal128: - case Decimal: - case UUID: - case FixedString: - case Nested: - case Tuple: - case Array: - case AggregateFunction: - case Unknown: - default: - throw new AssertionError(type); - } - } - } - - public static class ClickHouseUInt16Constant extends ClickHouseConstant { - - private final long value; - - public ClickHouseUInt16Constant(long value) { - this.value = value; - } - - @Override - public String toString() { - return String.valueOf(value); - } - - @Override - public boolean isNull() { - return false; - } - - @Override - public boolean asBooleanNotNull() { - return value != 0; - } - - @Override - public ClickHouseDataType getDataType() { - return ClickHouseDataType.UInt16; - } - - @Override - public boolean compareInternal(Object val) { - return value == (long) val; - } - - @Override - public ClickHouseConstant applyLess(ClickHouseConstant right) { - if (this.getDataType() == right.getDataType()) { - return this.asInt() < right.asInt() ? ClickHouseConstant.createTrue() - : ClickHouseConstant.createFalse(); - } - throw new IgnoreMeException(); - } - - @Override - public long asInt() { - return value; - } - - @Override - public Object getValue() { - return value; - } - - @Override - public ClickHouseConstant cast(ClickHouseDataType type) { - switch (type) { - case String: - return ClickHouseConstant.createStringConstant(this.toString()); - case UInt8: - return ClickHouseConstant.createUInt8Constant(value); - case Int8: - return ClickHouseConstant.createInt8Constant(value); - case UInt16: - return ClickHouseConstant.createUInt16Constant(value); - case Int16: - return ClickHouseConstant.createInt16Constant(value); - case UInt32: - return ClickHouseConstant.createUInt32Constant(value); - case Int32: - return ClickHouseConstant.createInt32Constant(value); - case UInt64: - return ClickHouseConstant.createUInt64Constant(BigInteger.valueOf(value)); - case Int64: - return ClickHouseConstant.createInt64Constant(BigInteger.valueOf(value)); - case Float32: - return ClickHouseConstant.createFloat32Constant((float) value); - case Float64: - return ClickHouseConstant.createFloat64Constant(value); - case Nothing: - return ClickHouseConstant.createNullConstant(); - case IntervalYear: - case IntervalQuarter: - case IntervalMonth: - case IntervalWeek: - case IntervalDay: - case IntervalHour: - case IntervalMinute: - case IntervalSecond: - case Date: - case DateTime: - case Enum8: - case Enum16: - case Decimal32: - case Decimal64: - case Decimal128: - case Decimal: - case UUID: - case FixedString: - case Nested: - case Tuple: - case Array: - case AggregateFunction: - case Unknown: - default: - throw new AssertionError(type); - } - } - } - - public static class ClickHouseInt16Constant extends ClickHouseConstant { - - private final long value; - - public ClickHouseInt16Constant(long value) { - this.value = value; - } - - @Override - public String toString() { - return String.valueOf(value); - } - - @Override - public boolean isNull() { - return false; - } - - @Override - public boolean asBooleanNotNull() { - return value != 0; - } - - @Override - public ClickHouseDataType getDataType() { - return ClickHouseDataType.Int16; - } - - @Override - public boolean compareInternal(Object val) { - return value == (long) val; - } - - @Override - public ClickHouseConstant applyLess(ClickHouseConstant right) { - if (this.getDataType() == right.getDataType()) { - return this.asInt() < right.asInt() ? ClickHouseConstant.createTrue() - : ClickHouseConstant.createFalse(); - } - throw new IgnoreMeException(); - } - - @Override - public long asInt() { - return value; - } - - @Override - public Object getValue() { - return value; - } - - @Override - public ClickHouseConstant cast(ClickHouseDataType type) { - switch (type) { - case String: - return ClickHouseConstant.createStringConstant(this.toString()); - case UInt8: - return ClickHouseConstant.createUInt8Constant(value); - case Int8: - return ClickHouseConstant.createInt8Constant(value); - case UInt16: - return ClickHouseConstant.createUInt16Constant(value); - case Int16: - return ClickHouseConstant.createInt16Constant(value); - case UInt32: - return ClickHouseConstant.createUInt32Constant(value); - case Int32: - return ClickHouseConstant.createInt32Constant(value); - case UInt64: - return ClickHouseConstant.createUInt64Constant(BigInteger.valueOf(value)); - case Int64: - return ClickHouseConstant.createInt64Constant(BigInteger.valueOf(value)); - case Float32: - return ClickHouseConstant.createFloat32Constant((float) value); - case Float64: - return ClickHouseConstant.createFloat64Constant(value); - case Nothing: - return ClickHouseConstant.createNullConstant(); - case IntervalYear: - case IntervalQuarter: - case IntervalMonth: - case IntervalWeek: - case IntervalDay: - case IntervalHour: - case IntervalMinute: - case IntervalSecond: - case Date: - case DateTime: - case Enum8: - case Enum16: - case Decimal32: - case Decimal64: - case Decimal128: - case Decimal: - case UUID: - case FixedString: - case Nested: - case Tuple: - case Array: - case AggregateFunction: - case Unknown: - default: - throw new AssertionError(type); - } - } - } - - public static class ClickHouseUInt32Constant extends ClickHouseConstant { - - private final long value; - - public ClickHouseUInt32Constant(long value) { - this.value = value; - } - - @Override - public String toString() { - return String.valueOf(value); - } - - @Override - public boolean isNull() { - return false; - } - - @Override - public boolean asBooleanNotNull() { - return value != 0; - } - - @Override - public ClickHouseDataType getDataType() { - return ClickHouseDataType.UInt32; - } - - @Override - public boolean compareInternal(Object val) { - return value == (long) val; - } - - @Override - public ClickHouseConstant applyLess(ClickHouseConstant right) { - if (this.getDataType() == right.getDataType()) { - return this.asInt() < right.asInt() ? ClickHouseConstant.createTrue() - : ClickHouseConstant.createFalse(); - } - throw new IgnoreMeException(); - } - - @Override - public long asInt() { - return value; - } - - @Override - public Object getValue() { - return value; - } - - @Override - public ClickHouseConstant cast(ClickHouseDataType type) { - switch (type) { - case String: - return ClickHouseConstant.createStringConstant(this.toString()); - case UInt8: - return ClickHouseConstant.createUInt8Constant(value); - case Int8: - return ClickHouseConstant.createInt8Constant(value); - case UInt16: - return ClickHouseConstant.createUInt16Constant(value); - case Int16: - return ClickHouseConstant.createInt16Constant(value); - case UInt32: - return ClickHouseConstant.createUInt32Constant(value); - case Int32: - return ClickHouseConstant.createInt32Constant(value); - case UInt64: - return ClickHouseConstant.createUInt64Constant(BigInteger.valueOf(value)); - case Int64: - return ClickHouseConstant.createInt64Constant(BigInteger.valueOf(value)); - case Float32: - return ClickHouseConstant.createFloat32Constant((float) value); - case Float64: - return ClickHouseConstant.createFloat64Constant(value); - case Nothing: - return ClickHouseConstant.createNullConstant(); - case IntervalYear: - case IntervalQuarter: - case IntervalMonth: - case IntervalWeek: - case IntervalDay: - case IntervalHour: - case IntervalMinute: - case IntervalSecond: - case Date: - case DateTime: - case Enum8: - case Enum16: - case Decimal32: - case Decimal64: - case Decimal128: - case Decimal: - case UUID: - case FixedString: - case Nested: - case Tuple: - case Array: - case AggregateFunction: - case Unknown: - default: - throw new AssertionError(type); - } - } - } - - public static class ClickHouseInt32Constant extends ClickHouseConstant { - - private final long value; - - public ClickHouseInt32Constant(long value) { - this.value = value; - } - - @Override - public String toString() { - return String.valueOf(value); - } - - @Override - public boolean isNull() { - return false; - } - - @Override - public boolean asBooleanNotNull() { - return value != 0; - } - - @Override - public ClickHouseDataType getDataType() { - return ClickHouseDataType.Int32; - } - - @Override - public boolean compareInternal(Object val) { - return value == (long) val; - } - - @Override - public ClickHouseConstant applyLess(ClickHouseConstant right) { - if (this.getDataType() == right.getDataType()) { - return this.asInt() < right.asInt() ? ClickHouseConstant.createTrue() - : ClickHouseConstant.createFalse(); - } - throw new IgnoreMeException(); - } - - @Override - public long asInt() { - return value; - } - - @Override - public Object getValue() { - return value; - } - - @Override - public ClickHouseConstant cast(ClickHouseDataType type) { - switch (type) { - case String: - return ClickHouseConstant.createStringConstant(this.toString()); - case UInt8: - return ClickHouseConstant.createUInt8Constant(value); - case Int8: - return ClickHouseConstant.createInt8Constant(value); - case UInt16: - return ClickHouseConstant.createUInt16Constant(value); - case Int16: - return ClickHouseConstant.createInt16Constant(value); - case UInt32: - return ClickHouseConstant.createUInt32Constant(value); - case Int32: - return ClickHouseConstant.createInt32Constant(value); - case UInt64: - return ClickHouseConstant.createUInt64Constant(BigInteger.valueOf(value)); - case Int64: - return ClickHouseConstant.createInt64Constant(BigInteger.valueOf(value)); - case Float32: - return ClickHouseConstant.createFloat32Constant((float) value); - case Float64: - return ClickHouseConstant.createFloat64Constant(value); - case Nothing: - return ClickHouseConstant.createNullConstant(); - case IntervalYear: - case IntervalQuarter: - case IntervalMonth: - case IntervalWeek: - case IntervalDay: - case IntervalHour: - case IntervalMinute: - case IntervalSecond: - case Date: - case DateTime: - case Enum8: - case Enum16: - case Decimal32: - case Decimal64: - case Decimal128: - case Decimal: - case UUID: - case FixedString: - case Nested: - case Tuple: - case Array: - case AggregateFunction: - case Unknown: - default: - throw new AssertionError(type); - } - } - } - - public static class ClickHouseUInt64Constant extends ClickHouseConstant { - - private final BigInteger value; - - public ClickHouseUInt64Constant(BigInteger value) { - this.value = value; - } - - @Override - public String toString() { - return String.valueOf(value); - } - - @Override - public boolean isNull() { - return false; - } - - @Override - public boolean asBooleanNotNull() { - return value != BigInteger.ZERO; - } - - @Override - public ClickHouseDataType getDataType() { - return ClickHouseDataType.UInt64; - } - - @Override - public boolean compareInternal(Object val) { - return value.compareTo((BigInteger) val) == 0; - } - - @Override - public ClickHouseConstant applyLess(ClickHouseConstant right) { - if (this.getDataType() == right.getDataType()) { - return this.asInt() < right.asInt() ? ClickHouseConstant.createTrue() - : ClickHouseConstant.createFalse(); - } - throw new IgnoreMeException(); - } - - @Override - public long asInt() { - return value.longValueExact(); - } - - @Override - public Object getValue() { - return value; - } - - @Override - public ClickHouseConstant cast(ClickHouseDataType type) { - long val = value.longValueExact(); - switch (type) { - case String: - return ClickHouseConstant.createStringConstant(this.toString()); - case UInt8: - return ClickHouseConstant.createUInt8Constant(val); - case Int8: - return ClickHouseConstant.createInt8Constant(val); - case UInt16: - return ClickHouseConstant.createUInt16Constant(val); - case Int16: - return ClickHouseConstant.createInt16Constant(val); - case UInt32: - return ClickHouseConstant.createUInt32Constant(val); - case Int32: - return ClickHouseConstant.createInt32Constant(val); - case UInt64: - return ClickHouseConstant.createUInt64Constant(value); - case Int64: - return ClickHouseConstant.createInt64Constant(value); - case Float32: - return ClickHouseConstant.createFloat32Constant(value.floatValue()); - case Float64: - return ClickHouseConstant.createFloat64Constant(value.doubleValue()); - case Nothing: - return ClickHouseConstant.createNullConstant(); - case IntervalYear: - case IntervalQuarter: - case IntervalMonth: - case IntervalWeek: - case IntervalDay: - case IntervalHour: - case IntervalMinute: - case IntervalSecond: - case Date: - case DateTime: - case Enum8: - case Enum16: - case Decimal32: - case Decimal64: - case Decimal128: - case Decimal: - case UUID: - case FixedString: - case Nested: - case Tuple: - case Array: - case AggregateFunction: - case Unknown: - default: - throw new AssertionError(type); - } - } - } - - public static class ClickHouseInt64Constant extends ClickHouseConstant { - - private final BigInteger value; - - public ClickHouseInt64Constant(BigInteger value) { - this.value = value; - } - - @Override - public String toString() { - return String.valueOf(value); - } - - @Override - public boolean isNull() { - return false; - } - - @Override - public boolean asBooleanNotNull() { - return value != BigInteger.ZERO; - } - - @Override - public ClickHouseDataType getDataType() { - return ClickHouseDataType.Int64; - } - - @Override - public boolean compareInternal(Object val) { - return value.compareTo((BigInteger) val) == 0; - } - - @Override - public ClickHouseConstant applyLess(ClickHouseConstant right) { - if (this.getDataType() == right.getDataType()) { - return this.asInt() < right.asInt() ? ClickHouseConstant.createTrue() - : ClickHouseConstant.createFalse(); - } - throw new IgnoreMeException(); - } - - @Override - public long asInt() { - return value.longValueExact(); - } - - @Override - public Object getValue() { - return value; - } - - @Override - public ClickHouseConstant cast(ClickHouseDataType type) { - long val = value.longValueExact(); - switch (type) { - case String: - return ClickHouseConstant.createStringConstant(this.toString()); - case UInt8: - return ClickHouseConstant.createUInt8Constant(val); - case Int8: - return ClickHouseConstant.createInt8Constant(val); - case UInt16: - return ClickHouseConstant.createUInt16Constant(val); - case Int16: - return ClickHouseConstant.createInt16Constant(val); - case UInt32: - return ClickHouseConstant.createUInt32Constant(val); - case Int32: - return ClickHouseConstant.createInt32Constant(val); - case UInt64: - return ClickHouseConstant.createUInt64Constant(value); - case Int64: - return ClickHouseConstant.createInt64Constant(value); - case Float32: - return ClickHouseConstant.createFloat32Constant(value.floatValue()); - case Float64: - return ClickHouseConstant.createFloat64Constant(value.doubleValue()); - case Nothing: - return ClickHouseConstant.createNullConstant(); - case IntervalYear: - case IntervalQuarter: - case IntervalMonth: - case IntervalWeek: - case IntervalDay: - case IntervalHour: - case IntervalMinute: - case IntervalSecond: - case Date: - case DateTime: - case Enum8: - case Enum16: - case Decimal32: - case Decimal64: - case Decimal128: - case Decimal: - case UUID: - case FixedString: - case Nested: - case Tuple: - case Array: - case AggregateFunction: - case Unknown: - default: - throw new AssertionError(type); - } - } - } - - public static class ClickHouseUInt128Constant extends ClickHouseConstant { - - private final BigInteger value; - - public ClickHouseUInt128Constant(BigInteger value) { - this.value = value; - } - - @Override - public String toString() { - return String.valueOf(value); - } - - @Override - public boolean isNull() { - return false; - } - - @Override - public boolean asBooleanNotNull() { - return value != BigInteger.ZERO; - } - - @Override - public ClickHouseDataType getDataType() { - return ClickHouseDataType.UInt128; - } - - @Override - public boolean compareInternal(Object val) { - return value.compareTo((BigInteger) val) == 0; - } - - @Override - public ClickHouseConstant applyLess(ClickHouseConstant right) { - if (this.getDataType() == right.getDataType()) { - return this.asInt() < right.asInt() ? ClickHouseConstant.createTrue() - : ClickHouseConstant.createFalse(); - } - throw new IgnoreMeException(); - } - - @Override - public long asInt() { - return value.longValueExact(); - } - - @Override - public Object getValue() { - return value; - } - - @Override - public ClickHouseConstant cast(ClickHouseDataType type) { - long val = value.longValueExact(); - switch (type) { - case String: - return ClickHouseConstant.createStringConstant(this.toString()); - case UInt8: - return ClickHouseConstant.createUInt8Constant(val); - case Int8: - return ClickHouseConstant.createInt8Constant(val); - case UInt16: - return ClickHouseConstant.createUInt16Constant(val); - case Int16: - return ClickHouseConstant.createInt16Constant(val); - case UInt32: - return ClickHouseConstant.createUInt32Constant(val); - case Int32: - return ClickHouseConstant.createInt32Constant(val); - case UInt64: - return ClickHouseConstant.createUInt64Constant(value); - case Int64: - return ClickHouseConstant.createInt64Constant(value); - case UInt128: - return ClickHouseConstant.createUInt128Constant(value); - case Int128: - return ClickHouseConstant.createInt128Constant(value); - case UInt256: - return ClickHouseConstant.createUInt256Constant(value); - case Int256: - return ClickHouseConstant.createInt256Constant(value); - case Float32: - return ClickHouseConstant.createFloat32Constant(value.floatValue()); - case Float64: - return ClickHouseConstant.createFloat64Constant(value.doubleValue()); - case Nothing: - return ClickHouseConstant.createNullConstant(); - case IntervalYear: - case IntervalQuarter: - case IntervalMonth: - case IntervalWeek: - case IntervalDay: - case IntervalHour: - case IntervalMinute: - case IntervalSecond: - case Date: - case DateTime: - case Enum8: - case Enum16: - case Decimal32: - case Decimal64: - case Decimal128: - case Decimal: - case UUID: - case FixedString: - case Nested: - case Tuple: - case Array: - case AggregateFunction: - case Unknown: - default: - throw new AssertionError(type); - } - } - } - - public static class ClickHouseInt128Constant extends ClickHouseConstant { - - private final BigInteger value; - - public ClickHouseInt128Constant(BigInteger value) { - this.value = value; - } - - @Override - public String toString() { - return String.valueOf(value); - } - - @Override - public boolean isNull() { - return false; - } - - @Override - public boolean asBooleanNotNull() { - return value != BigInteger.ZERO; - } - - @Override - public ClickHouseDataType getDataType() { - return ClickHouseDataType.Int128; - } - - @Override - public boolean compareInternal(Object val) { - return value.compareTo((BigInteger) val) == 0; - } - - @Override - public ClickHouseConstant applyLess(ClickHouseConstant right) { - if (this.getDataType() == right.getDataType()) { - return this.asInt() < right.asInt() ? ClickHouseConstant.createTrue() - : ClickHouseConstant.createFalse(); - } - throw new IgnoreMeException(); - } - - @Override - public long asInt() { - return value.longValueExact(); - } - - @Override - public Object getValue() { - return value; - } - - @Override - public ClickHouseConstant cast(ClickHouseDataType type) { - long val = value.longValueExact(); - switch (type) { - case String: - return ClickHouseConstant.createStringConstant(this.toString()); - case UInt8: - return ClickHouseConstant.createUInt8Constant(val); - case Int8: - return ClickHouseConstant.createInt8Constant(val); - case UInt16: - return ClickHouseConstant.createUInt16Constant(val); - case Int16: - return ClickHouseConstant.createInt16Constant(val); - case UInt32: - return ClickHouseConstant.createUInt32Constant(val); - case Int32: - return ClickHouseConstant.createInt32Constant(val); - case UInt64: - return ClickHouseConstant.createUInt64Constant(value); - case Int64: - return ClickHouseConstant.createInt64Constant(value); - case UInt128: - return ClickHouseConstant.createUInt128Constant(value); - case Int128: - return ClickHouseConstant.createInt128Constant(value); - case UInt256: - return ClickHouseConstant.createUInt256Constant(value); - case Int256: - return ClickHouseConstant.createInt256Constant(value); - case Float32: - return ClickHouseConstant.createFloat32Constant(value.floatValue()); - case Float64: - return ClickHouseConstant.createFloat64Constant(value.doubleValue()); - case Nothing: - return ClickHouseConstant.createNullConstant(); - case IntervalYear: - case IntervalQuarter: - case IntervalMonth: - case IntervalWeek: - case IntervalDay: - case IntervalHour: - case IntervalMinute: - case IntervalSecond: - case Date: - case DateTime: - case Enum8: - case Enum16: - case Decimal32: - case Decimal64: - case Decimal128: - case Decimal: - case UUID: - case FixedString: - case Nested: - case Tuple: - case Array: - case AggregateFunction: - case Unknown: - default: - throw new AssertionError(type); - } - } - } - - public static class ClickHouseUInt256Constant extends ClickHouseConstant { - - private final BigInteger value; - - public ClickHouseUInt256Constant(BigInteger value) { - this.value = value; - } - - @Override - public String toString() { - return String.valueOf(value); - } - - @Override - public boolean isNull() { - return false; - } - - @Override - public boolean asBooleanNotNull() { - return value != BigInteger.ZERO; - } - - @Override - public ClickHouseDataType getDataType() { - return ClickHouseDataType.UInt256; - } - - @Override - public boolean compareInternal(Object val) { - return value.compareTo((BigInteger) val) == 0; - } - - @Override - public ClickHouseConstant applyLess(ClickHouseConstant right) { - if (this.getDataType() == right.getDataType()) { - return this.asInt() < right.asInt() ? ClickHouseConstant.createTrue() - : ClickHouseConstant.createFalse(); - } - throw new IgnoreMeException(); - } - - @Override - public long asInt() { - return value.longValueExact(); - } - - @Override - public Object getValue() { - return value; - } - - @Override - public ClickHouseConstant cast(ClickHouseDataType type) { - long val = value.longValueExact(); - switch (type) { - case String: - return ClickHouseConstant.createStringConstant(this.toString()); - case UInt8: - return ClickHouseConstant.createUInt8Constant(val); - case Int8: - return ClickHouseConstant.createInt8Constant(val); - case UInt16: - return ClickHouseConstant.createUInt16Constant(val); - case Int16: - return ClickHouseConstant.createInt16Constant(val); - case UInt32: - return ClickHouseConstant.createUInt32Constant(val); - case Int32: - return ClickHouseConstant.createInt32Constant(val); - case UInt64: - return ClickHouseConstant.createUInt64Constant(value); - case Int64: - return ClickHouseConstant.createInt64Constant(value); - case UInt128: - return ClickHouseConstant.createUInt128Constant(value); - case Int128: - return ClickHouseConstant.createInt128Constant(value); - case UInt256: - return ClickHouseConstant.createUInt256Constant(value); - case Int256: - return ClickHouseConstant.createInt256Constant(value); - case Float32: - return ClickHouseConstant.createFloat32Constant(value.floatValue()); - case Float64: - return ClickHouseConstant.createFloat64Constant(value.doubleValue()); - case Nothing: - return ClickHouseConstant.createNullConstant(); - case IntervalYear: - case IntervalQuarter: - case IntervalMonth: - case IntervalWeek: - case IntervalDay: - case IntervalHour: - case IntervalMinute: - case IntervalSecond: - case Date: - case DateTime: - case Enum8: - case Enum16: - case Decimal32: - case Decimal64: - case Decimal128: - case Decimal: - case UUID: - case FixedString: - case Nested: - case Tuple: - case Array: - case AggregateFunction: - case Unknown: - default: - throw new AssertionError(type); - } - } - } - - public static class ClickHouseInt256Constant extends ClickHouseConstant { - - private final BigInteger value; - - public ClickHouseInt256Constant(BigInteger value) { - this.value = value; - } - - @Override - public String toString() { - return String.valueOf(value); - } - - @Override - public boolean isNull() { - return false; - } - - @Override - public boolean asBooleanNotNull() { - return value != BigInteger.ZERO; - } - - @Override - public ClickHouseDataType getDataType() { - return ClickHouseDataType.Int256; - } - - @Override - public boolean compareInternal(Object val) { - return value.compareTo((BigInteger) val) == 0; - } - - @Override - public ClickHouseConstant applyLess(ClickHouseConstant right) { - if (this.getDataType() == right.getDataType()) { - return this.asInt() < right.asInt() ? ClickHouseConstant.createTrue() - : ClickHouseConstant.createFalse(); - } - throw new IgnoreMeException(); - } - - @Override - public long asInt() { - return value.longValueExact(); - } - - @Override - public Object getValue() { - return value; - } - - @Override - public ClickHouseConstant cast(ClickHouseDataType type) { - long val = value.longValueExact(); - switch (type) { - case String: - return ClickHouseConstant.createStringConstant(this.toString()); - case UInt8: - return ClickHouseConstant.createUInt8Constant(val); - case Int8: - return ClickHouseConstant.createInt8Constant(val); - case UInt16: - return ClickHouseConstant.createUInt16Constant(val); - case Int16: - return ClickHouseConstant.createInt16Constant(val); - case UInt32: - return ClickHouseConstant.createUInt32Constant(val); - case Int32: - return ClickHouseConstant.createInt32Constant(val); - case UInt64: - return ClickHouseConstant.createUInt64Constant(value); - case Int64: - return ClickHouseConstant.createInt64Constant(value); - case UInt128: - return ClickHouseConstant.createUInt128Constant(value); - case Int128: - return ClickHouseConstant.createInt128Constant(value); - case UInt256: - return ClickHouseConstant.createUInt256Constant(value); - case Int256: - return ClickHouseConstant.createInt256Constant(value); - case Float32: - return ClickHouseConstant.createFloat32Constant(value.floatValue()); - case Float64: - return ClickHouseConstant.createFloat64Constant(value.doubleValue()); - case Nothing: - return ClickHouseConstant.createNullConstant(); - case IntervalYear: - case IntervalQuarter: - case IntervalMonth: - case IntervalWeek: - case IntervalDay: - case IntervalHour: - case IntervalMinute: - case IntervalSecond: - case Date: - case DateTime: - case Enum8: - case Enum16: - case Decimal32: - case Decimal64: - case Decimal128: - case Decimal: - case UUID: - case FixedString: - case Nested: - case Tuple: - case Array: - case AggregateFunction: - case Unknown: - default: - throw new AssertionError(type); - } - } - } - - public static class ClickHouseFloat32Constant extends ClickHouseConstant { - - private final float value; - - public ClickHouseFloat32Constant(float value) { - this.value = value; - } - - @Override - public Object getValue() { - return value; - } - - @Override - public boolean isNull() { - return false; - } - - @Override - public String toString() { - if (value == Double.POSITIVE_INFINITY) { - return "'+Inf'"; - } else if (value == Double.NEGATIVE_INFINITY) { - return "'-Inf'"; - } - return String.valueOf(value); - } - - @Override - public boolean compareInternal(Object val) { - return Float.compare(value, (float) val) == 0; - } - - @Override - public ClickHouseConstant applyLess(ClickHouseConstant right) { - if (this.getDataType() == right.getDataType()) { - return this.asDouble() < right.asDouble() ? ClickHouseConstant.createTrue() - : ClickHouseConstant.createFalse(); - } - ClickHouseConstant converted = right.cast(ClickHouseDataType.Float32); - return this.asDouble() < converted.asDouble() ? ClickHouseConstant.createTrue() - : ClickHouseConstant.createFalse(); - } - - @Override - public boolean asBooleanNotNull() { - return Float.compare(value, (float) 0) == 0; - } - - @Override - public ClickHouseDataType getDataType() { - return ClickHouseDataType.Float32; - } - - @Override - public double asDouble() { - return value; - } - - @Override - public ClickHouseConstant cast(ClickHouseDataType type) { - switch (type) { - case String: - return ClickHouseConstant.createStringConstant(this.toString()); - case UInt8: - return ClickHouseConstant.createUInt8Constant((long) value); - case Int8: - return ClickHouseConstant.createInt8Constant((long) value); - case UInt16: - return ClickHouseConstant.createUInt16Constant((long) value); - case Int16: - return ClickHouseConstant.createInt16Constant((long) value); - case UInt32: - return ClickHouseConstant.createUInt32Constant((long) value); - case Int32: - return ClickHouseConstant.createInt32Constant((long) value); - case UInt64: - return ClickHouseConstant.createUInt64Constant(BigInteger.valueOf((long) value)); - case Int64: - return ClickHouseConstant.createInt64Constant(BigInteger.valueOf((long) value)); - case Float32: - return ClickHouseConstant.createFloat32Constant(value); - case Float64: - return ClickHouseConstant.createFloat64Constant(value); - case Nothing: - return ClickHouseConstant.createNullConstant(); - case IntervalYear: - case IntervalQuarter: - case IntervalMonth: - case IntervalWeek: - case IntervalDay: - case IntervalHour: - case IntervalMinute: - case IntervalSecond: - case Date: - case DateTime: - case Enum8: - case Enum16: - case Decimal32: - case Decimal64: - case Decimal128: - case Decimal: - case UUID: - case FixedString: - case Nested: - case Tuple: - case Array: - case AggregateFunction: - case Unknown: - default: - throw new AssertionError(type); - } - } - } - - public static class ClickHouseFloat64Constant extends ClickHouseConstant { - - private final double value; - - public ClickHouseFloat64Constant(double value) { - this.value = value; - } - - @Override - public Object getValue() { - return value; - } - - @Override - public boolean isNull() { - return false; - } - - @Override - public String toString() { - if (value == Double.POSITIVE_INFINITY) { - return "'+Inf'"; - } else if (value == Double.NEGATIVE_INFINITY) { - return "'-Inf'"; - } - return String.valueOf(value); - } - - @Override - public boolean compareInternal(Object val) { - return Double.compare(value, (double) val) == 0; - } - - @Override - public ClickHouseConstant applyLess(ClickHouseConstant right) { - if (this.getDataType() == right.getDataType()) { - return this.asDouble() < right.asDouble() ? ClickHouseConstant.createTrue() - : ClickHouseConstant.createFalse(); - } - ClickHouseConstant converted = right.cast(ClickHouseDataType.Float64); - return this.asDouble() < converted.asDouble() ? ClickHouseConstant.createTrue() - : ClickHouseConstant.createFalse(); - } - - @Override - public boolean asBooleanNotNull() { - return Double.compare(value, 0.0) == 0; - } - - @Override - public ClickHouseDataType getDataType() { - return ClickHouseDataType.Float64; - } - - @Override - public double asDouble() { - return value; - } - - @Override - public ClickHouseConstant cast(ClickHouseDataType type) { - switch (type) { - case String: - return ClickHouseConstant.createStringConstant(this.toString()); - case UInt8: - return ClickHouseConstant.createUInt8Constant((long) value); - case Int8: - return ClickHouseConstant.createInt8Constant((long) value); - case UInt16: - return ClickHouseConstant.createUInt16Constant((long) value); - case Int16: - return ClickHouseConstant.createInt16Constant((long) value); - case UInt32: - return ClickHouseConstant.createUInt32Constant((long) value); - case Int32: - return ClickHouseConstant.createInt32Constant((long) value); - case UInt64: - return ClickHouseConstant.createUInt64Constant(BigInteger.valueOf((long) value)); - case Int64: - return ClickHouseConstant.createInt64Constant(BigInteger.valueOf((long) value)); - case Float32: - return ClickHouseConstant.createFloat32Constant((float) value); - case Float64: - return ClickHouseConstant.createFloat64Constant(value); - case Nothing: - return ClickHouseConstant.createNullConstant(); - case IntervalYear: - case IntervalQuarter: - case IntervalMonth: - case IntervalWeek: - case IntervalDay: - case IntervalHour: - case IntervalMinute: - case IntervalSecond: - case Date: - case DateTime: - case Enum8: - case Enum16: - case Decimal32: - case Decimal64: - case Decimal128: - case Decimal: - case UUID: - case FixedString: - case Nested: - case Tuple: - case Array: - case AggregateFunction: - case Unknown: - default: - throw new AssertionError(type); - } - } - } - - public static class ClickHouseStringConstant extends ClickHouseConstant { - - private final String value; - - public ClickHouseStringConstant(String value) { - this.value = value; - } - - @Override - public boolean isNull() { - return false; - } - - @Override - public Object getValue() { - return value; - } - - @Override - public String toString() { - return "'" + value.replace("\\", "\\\\").replace("'", "\\'") + "'"; - } - - @Override - public boolean asBooleanNotNull() { - return value.length() > 0; - } - - @Override - public boolean compareInternal(Object val) { - return value.compareTo((String) val) == 0; - } - - @Override - public ClickHouseConstant applyLess(ClickHouseConstant right) { - if (this.getDataType() == right.getDataType()) { - return this.asString().compareTo(right.asString()) <= 0 ? ClickHouseConstant.createTrue() - : ClickHouseConstant.createFalse(); - } - throw new IgnoreMeException(); - } - - @Override - public ClickHouseDataType getDataType() { - return ClickHouseDataType.String; - } - - @Override - public ClickHouseConstant cast(ClickHouseDataType type) { - switch (type) { - case String: - return ClickHouseConstant.createStringConstant(this.toString()); - case UInt8: - return ClickHouseConstant.createUInt8Constant(Integer.parseInt(value)); - case Int8: - return ClickHouseConstant.createInt8Constant(Integer.parseInt(value)); - case UInt16: - return ClickHouseConstant.createUInt16Constant(Integer.parseInt(value)); - case Int16: - return ClickHouseConstant.createInt16Constant(Integer.parseInt(value)); - case UInt32: - return ClickHouseConstant.createUInt32Constant(Integer.parseInt(value)); - case Int32: - return ClickHouseConstant.createInt32Constant(Integer.parseInt(value)); - case UInt64: - return ClickHouseConstant.createUInt64Constant(BigInteger.valueOf(Integer.parseInt(value))); - case Int64: - return ClickHouseConstant.createInt64Constant(BigInteger.valueOf(Integer.parseInt(value))); - case Float32: - return ClickHouseConstant.createFloat32Constant((float) Float.parseFloat(value)); - case Float64: - return ClickHouseConstant.createFloat64Constant((double) Double.parseDouble(value)); - case Nothing: - return ClickHouseConstant.createNullConstant(); - case IntervalYear: - case IntervalQuarter: - case IntervalMonth: - case IntervalWeek: - case IntervalDay: - case IntervalHour: - case IntervalMinute: - case IntervalSecond: - case Date: - case DateTime: - case Enum8: - case Enum16: - case Decimal32: - case Decimal64: - case Decimal128: - case Decimal: - case UUID: - case FixedString: - case Nested: - case Tuple: - case Array: - case AggregateFunction: - case Unknown: - default: - throw new AssertionError(type); - } - } - } - - public static ClickHouseConstant createStringConstant(String text) { - return new ClickHouseStringConstant(text); - } - - public static ClickHouseConstant createFloat64Constant(double val) { - return new ClickHouseFloat64Constant(val); - } - - public static ClickHouseConstant createFloat32Constant(float val) { - return new ClickHouseFloat32Constant(val); - } - - public static ClickHouseConstant createIntConstant(ClickHouseDataType type, long val) { - switch (type) { - case IntervalYear: - break; - case IntervalQuarter: - break; - case IntervalMonth: - break; - case IntervalWeek: - break; - case IntervalDay: - break; - case IntervalHour: - break; - case IntervalMinute: - break; - case IntervalSecond: - break; - case UInt256: - return createUInt256Constant(BigInteger.valueOf(val)); - case UInt128: - return createUInt128Constant(BigInteger.valueOf(val)); - case UInt64: - return createUInt64Constant(BigInteger.valueOf(val)); - case UInt32: - return createUInt32Constant(val); - case UInt16: - return createUInt16Constant(val); - case UInt8: - return createUInt8Constant(val); - case Int256: - return createInt256Constant(BigInteger.valueOf(val)); - case Int128: - return createInt256Constant(BigInteger.valueOf(val)); - case Int64: - return createInt64Constant(BigInteger.valueOf(val)); - case Int32: - return createInt32Constant(val); - case Int16: - return createInt16Constant(val); - case Int8: - return createInt8Constant(val); - case Date: - break; - case DateTime: - break; - case Enum8: - break; - case Enum16: - break; - case Float32: - break; - case Float64: - break; - case Decimal32: - break; - case Decimal64: - break; - case Decimal128: - break; - case Decimal: - break; - case UUID: - break; - case String: - break; - case FixedString: - break; - case Nothing: - break; - case Nested: - break; - case Tuple: - break; - case Array: - break; - case AggregateFunction: - break; - case Unknown: - break; - default: - break; - } - throw new AssertionError(type); - } - - public static ClickHouseConstant createInt256Constant(BigInteger val) { - return new ClickHouseInt256Constant(val); - } - - public static ClickHouseConstant createUInt256Constant(BigInteger val) { - return new ClickHouseUInt256Constant(val); - } - - public static ClickHouseConstant createInt128Constant(BigInteger val) { - return new ClickHouseInt128Constant(val); - } - - public static ClickHouseConstant createUInt128Constant(BigInteger val) { - return new ClickHouseUInt128Constant(val); - } - - public static ClickHouseConstant createInt64Constant(BigInteger val) { - return new ClickHouseInt64Constant(val); - } - - public static ClickHouseConstant createUInt64Constant(BigInteger val) { - return new ClickHouseUInt64Constant(val); - } - - public static ClickHouseConstant createInt32Constant(long val) { - return new ClickHouseInt32Constant(val); - } - - public static ClickHouseConstant createUInt32Constant(long val) { - return new ClickHouseUInt32Constant(val); - } - - public static ClickHouseConstant createUInt16Constant(long val) { - return new ClickHouseUInt16Constant(val); - } - - public static ClickHouseConstant createInt16Constant(long val) { - return new ClickHouseInt16Constant(val); - } - - public static ClickHouseConstant createUInt8Constant(long val) { - return new ClickHouseUInt8Constant((int) val); - } - - public static ClickHouseConstant createInt8Constant(long val) { - return new ClickHouseInt8Constant((int) val); - } - public abstract boolean isNull(); - public static ClickHouseConstant createNullConstant() { - return new ClickHouseNullConstant(); - } - - public static ClickHouseConstant createTrue() { - return new ClickHouseUInt8Constant(1); - } - - public static ClickHouseConstant createFalse() { - return new ClickHouseUInt8Constant(0); - } - - public static ClickHouseConstant createBoolean(boolean val) { - return val ? createTrue() : createFalse(); - } - public abstract ClickHouseConstant cast(ClickHouseDataType type); public abstract boolean asBooleanNotNull(); @@ -1904,8 +18,8 @@ public static ClickHouseConstant createBoolean(boolean val) { public ClickHouseConstant applyEquals(ClickHouseConstant right) { if (this.getDataType() == right.getDataType()) { - return this.compareInternal(right.getValue()) ? ClickHouseConstant.createTrue() - : ClickHouseConstant.createFalse(); + return this.compareInternal(right.getValue()) ? ClickHouseCreateConstant.createTrue() + : ClickHouseCreateConstant.createFalse(); } else { ClickHouseConstant converted = right.cast(this.getDataType()); return this.applyEquals(converted); diff --git a/src/sqlancer/clickhouse/ast/ClickHouseExpression.java b/src/sqlancer/clickhouse/ast/ClickHouseExpression.java index c3642c208..eff88d012 100644 --- a/src/sqlancer/clickhouse/ast/ClickHouseExpression.java +++ b/src/sqlancer/clickhouse/ast/ClickHouseExpression.java @@ -1,9 +1,13 @@ package sqlancer.clickhouse.ast; -import sqlancer.clickhouse.ClickHouseSchema; +import sqlancer.clickhouse.ClickHouseSchema.ClickHouseColumn; +import sqlancer.clickhouse.ClickHouseSchema.ClickHouseTable; +import sqlancer.common.ast.newast.Expression; +import sqlancer.common.ast.newast.Join; +import sqlancer.common.visitor.BinaryOperation; import sqlancer.common.visitor.UnaryOperation; -public abstract class ClickHouseExpression { +public abstract class ClickHouseExpression implements Expression { public ClickHouseConstant getExpectedValue() { return null; @@ -35,35 +39,71 @@ public ClickHouseExpression getExpression() { } } - public static class ClickHouseJoin extends ClickHouseExpression { + public static class ClickHouseJoinOnClause extends ClickHouseExpression + implements BinaryOperation { + private final ClickHouseExpression left; + private final ClickHouseExpression right; + public ClickHouseJoinOnClause(ClickHouseExpression left, ClickHouseExpression right) { + this.left = left; + this.right = right; + } + + @Override + public final ClickHouseExpression getLeft() { + return this.left; + } + + @Override + public final ClickHouseExpression getRight() { + return this.right; + } + + @Override + public String getOperatorRepresentation() { + return "="; + } + } + + public static class ClickHouseJoin extends ClickHouseExpression + implements Join { // TODO: support ANY, ALL, ASOF modifiers + // LEFT_SEMI, RIGHT_SEMI are not deterministic as ClickHouse allows to read columns from + // whitelist table as well public enum JoinType { - INNER, CROSS, LEFT_OUTER, RIGHT_OUTER, FULL_OUTER, NATURAL, LEFT_SEMI, RIGHT_SEMI, LEFT_ANTI, RIGHT_ANTI; + INNER, CROSS, LEFT_OUTER, RIGHT_OUTER, FULL_OUTER, LEFT_ANTI, RIGHT_ANTI; } - private final ClickHouseSchema.ClickHouseTable table; + private final ClickHouseTableReference leftTable; + private final ClickHouseTableReference rightTable; private ClickHouseExpression onClause; private final ClickHouseJoin.JoinType type; - public ClickHouseJoin(ClickHouseSchema.ClickHouseTable table, ClickHouseExpression onClause, - ClickHouseJoin.JoinType type) { - this.table = table; + public ClickHouseJoin(ClickHouseTableReference leftTable, ClickHouseTableReference rightTable, + ClickHouseJoin.JoinType type, ClickHouseJoinOnClause onClause) { + this.leftTable = leftTable; + this.rightTable = rightTable; this.onClause = onClause; this.type = type; } - public ClickHouseJoin(ClickHouseSchema.ClickHouseTable table, ClickHouseJoin.JoinType type) { - this.table = table; - if (type != ClickHouseJoin.JoinType.NATURAL) { + public ClickHouseJoin(ClickHouseTableReference leftTable, ClickHouseTableReference rightTable, + ClickHouseJoin.JoinType type) { + this.leftTable = leftTable; + this.rightTable = rightTable; + if (type != ClickHouseJoin.JoinType.CROSS) { throw new AssertionError(); } this.onClause = null; this.type = type; } - public ClickHouseSchema.ClickHouseTable getTable() { - return table; + public ClickHouseTableReference getLeftTable() { + return leftTable; + } + + public ClickHouseTableReference getRightTable() { + return rightTable; } public ClickHouseExpression getOnClause() { @@ -74,6 +114,7 @@ public ClickHouseJoin.JoinType getType() { return type; } + @Override public void setOnClause(ClickHouseExpression onClause) { this.onClause = onClause; } diff --git a/src/sqlancer/clickhouse/ast/ClickHouseSelect.java b/src/sqlancer/clickhouse/ast/ClickHouseSelect.java index 7a72f2919..61aeffa4d 100644 --- a/src/sqlancer/clickhouse/ast/ClickHouseSelect.java +++ b/src/sqlancer/clickhouse/ast/ClickHouseSelect.java @@ -3,10 +3,16 @@ import java.util.Collections; import java.util.List; -public class ClickHouseSelect extends ClickHouseExpression { +import sqlancer.clickhouse.ClickHouseSchema.ClickHouseColumn; +import sqlancer.clickhouse.ClickHouseSchema.ClickHouseTable; +import sqlancer.clickhouse.ClickHouseToStringVisitor; +import sqlancer.common.ast.newast.Select; + +public class ClickHouseSelect extends ClickHouseExpression implements + Select { private ClickHouseSelect.SelectType fromOptions = ClickHouseSelect.SelectType.ALL; - private List fromList = Collections.emptyList(); + private List fromClauses; private ClickHouseExpression whereClause; private List groupByClause = Collections.emptyList(); private ClickHouseExpression limitClause; @@ -24,8 +30,13 @@ public void setSelectType(ClickHouseSelect.SelectType fromOptions) { this.setFromOptions(fromOptions); } - public void setFromTables(List fromTables) { - this.setFromList(fromTables); + public void setFromClause(ClickHouseExpression fromList) { + this.fromClauses = List.of(fromList); + } + + @Override + public List getFromList() { + return fromClauses; } public ClickHouseSelect.SelectType getFromOptions() { @@ -36,76 +47,94 @@ public void setFromOptions(ClickHouseSelect.SelectType fromOptions) { this.fromOptions = fromOptions; } - public List getFromList() { - return fromList; - } - - public void setFromList(List fromList) { - this.fromList = fromList; - } - + @Override public ClickHouseExpression getWhereClause() { return whereClause; } + @Override public void setWhereClause(ClickHouseExpression whereClause) { this.whereClause = whereClause; } + @Override public void setGroupByClause(List groupByClause) { this.groupByClause = groupByClause; } + @Override public List getGroupByClause() { return groupByClause; } + @Override public void setLimitClause(ClickHouseExpression limitClause) { this.limitClause = limitClause; } + @Override public ClickHouseExpression getLimitClause() { return limitClause; } - public List getOrderByClause() { + @Override + public List getOrderByClauses() { return orderByClause; } - public void setOrderByExpressions(List orderBy) { + @Override + public void setOrderByClauses(List orderBy) { this.orderByClause = orderBy; } + @Override public void setOffsetClause(ClickHouseExpression offsetClause) { this.offsetClause = offsetClause; } + @Override public ClickHouseExpression getOffsetClause() { return offsetClause; } + @Override public void setFetchColumns(List fetchColumns) { this.fetchColumns = fetchColumns; } + @Override public List getFetchColumns() { return fetchColumns; } + @Override public void setJoinClauses(List joinStatements) { this.joinStatements = joinStatements; } + @Override public List getJoinClauses() { return joinStatements; } + @Override public void setHavingClause(ClickHouseExpression havingClause) { this.havingClause = havingClause; } + @Override public ClickHouseExpression getHavingClause() { assert orderByClause != null; return havingClause; } + + @Override + public String asString() { + return ClickHouseToStringVisitor.asString(this); + } + + @Override + public void setFromList(List fromList) { + this.fromClauses = fromList; + } } diff --git a/src/sqlancer/clickhouse/ast/ClickHouseTableReference.java b/src/sqlancer/clickhouse/ast/ClickHouseTableReference.java index 415dfca22..406ffb15c 100644 --- a/src/sqlancer/clickhouse/ast/ClickHouseTableReference.java +++ b/src/sqlancer/clickhouse/ast/ClickHouseTableReference.java @@ -1,17 +1,34 @@ package sqlancer.clickhouse.ast; +import java.util.List; +import java.util.stream.Collectors; + import sqlancer.clickhouse.ClickHouseSchema.ClickHouseTable; public class ClickHouseTableReference extends ClickHouseExpression { private final ClickHouseTable table; + private final String alias; - public ClickHouseTableReference(ClickHouseTable table) { + public ClickHouseTableReference(ClickHouseTable table, String alias) { this.table = table; + this.alias = alias; } public ClickHouseTable getTable() { return table; } + public String getTableName() { + return (alias == null) ? table.getName() : alias; + } + + public String getAlias() { + return alias; + } + + public List getColumnReferences() { + return this.table.getColumns().stream().map(c -> c.asColumnReference(this.alias)).collect(Collectors.toList()); + } + } diff --git a/src/sqlancer/clickhouse/ast/ClickHouseUnaryFunctionOperation.java b/src/sqlancer/clickhouse/ast/ClickHouseUnaryFunctionOperation.java new file mode 100644 index 000000000..30707354f --- /dev/null +++ b/src/sqlancer/clickhouse/ast/ClickHouseUnaryFunctionOperation.java @@ -0,0 +1,52 @@ +package sqlancer.clickhouse.ast; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.common.visitor.UnaryOperation; + +public class ClickHouseUnaryFunctionOperation extends ClickHouseExpression + implements UnaryOperation { + private final ClickHouseUnaryFunctionOperation.ClickHouseUnaryFunctionOperator operator; + private final ClickHouseExpression expression; + + public ClickHouseUnaryFunctionOperation(ClickHouseExpression expression, ClickHouseUnaryFunctionOperator operator) { + this.operator = operator; + this.expression = expression; + } + + public enum ClickHouseUnaryFunctionOperator implements Operator { + EXP("exp"), SQRT("sqrt"), ERF("erf"), SIN("sin"), COS("cos"), TAN("tan"), SIGN("sign"), RADIANS("radians"), + LOG("log"), ABS("abs"); + + private String textRepresentation; + + ClickHouseUnaryFunctionOperator(String text) { + this.textRepresentation = text; + } + + public static ClickHouseUnaryFunctionOperator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepresentation; + } + } + + @Override + public ClickHouseExpression getExpression() { + return expression; + } + + @Override + public String getOperatorRepresentation() { + return operator.getTextRepresentation(); + } + + @Override + public OperatorKind getOperatorKind() { + return OperatorKind.PREFIX; + } + +} diff --git a/src/sqlancer/clickhouse/ast/ClickHouseUnaryPostfixOperation.java b/src/sqlancer/clickhouse/ast/ClickHouseUnaryPostfixOperation.java index 553f85070..444818413 100644 --- a/src/sqlancer/clickhouse/ast/ClickHouseUnaryPostfixOperation.java +++ b/src/sqlancer/clickhouse/ast/ClickHouseUnaryPostfixOperation.java @@ -1,6 +1,7 @@ package sqlancer.clickhouse.ast; import sqlancer.Randomly; +import sqlancer.clickhouse.ast.constant.ClickHouseCreateConstant; import sqlancer.common.ast.BinaryOperatorNode; import sqlancer.common.visitor.UnaryOperation; @@ -86,7 +87,7 @@ public ClickHouseConstant getExpectedValue() { if (negate) { val = !val; } - return ClickHouseConstant.createInt32Constant(val ? 1 : 0); + return ClickHouseCreateConstant.createInt32Constant(val ? 1 : 0); } } diff --git a/src/sqlancer/clickhouse/ast/ClickHouseUnaryPrefixOperation.java b/src/sqlancer/clickhouse/ast/ClickHouseUnaryPrefixOperation.java index 972136e64..15ab8fb36 100644 --- a/src/sqlancer/clickhouse/ast/ClickHouseUnaryPrefixOperation.java +++ b/src/sqlancer/clickhouse/ast/ClickHouseUnaryPrefixOperation.java @@ -1,7 +1,9 @@ package sqlancer.clickhouse.ast; -import ru.yandex.clickhouse.domain.ClickHouseDataType; +import com.clickhouse.client.ClickHouseDataType; + import sqlancer.Randomly; +import sqlancer.clickhouse.ast.constant.ClickHouseCreateConstant; import sqlancer.common.ast.BinaryOperatorNode.Operator; import sqlancer.common.visitor.UnaryOperation; @@ -20,10 +22,10 @@ public enum ClickHouseUnaryPrefixOperator implements Operator { @Override public ClickHouseConstant apply(ClickHouseConstant constant) { if (constant.getDataType() == ClickHouseDataType.Nothing) { - return ClickHouseConstant.createNullConstant(); + return ClickHouseCreateConstant.createNullConstant(); } else { - return constant.asBooleanNotNull() ? ClickHouseConstant.createFalse() - : ClickHouseConstant.createTrue(); + return constant.asBooleanNotNull() ? ClickHouseCreateConstant.createFalse() + : ClickHouseCreateConstant.createTrue(); } } }, @@ -31,7 +33,7 @@ public ClickHouseConstant apply(ClickHouseConstant constant) { @Override public ClickHouseConstant apply(ClickHouseConstant constant) { if (constant.getDataType() == ClickHouseDataType.Int32) { - return ClickHouseConstant.createInt32Constant(-constant.asInt()); + return ClickHouseCreateConstant.createInt32Constant(-constant.asInt()); } throw new AssertionError(constant); } diff --git a/src/sqlancer/clickhouse/ast/constant/ClickHouseBooleanConstant.java b/src/sqlancer/clickhouse/ast/constant/ClickHouseBooleanConstant.java new file mode 100644 index 000000000..d5164162c --- /dev/null +++ b/src/sqlancer/clickhouse/ast/constant/ClickHouseBooleanConstant.java @@ -0,0 +1,142 @@ +package sqlancer.clickhouse.ast.constant; + +import java.math.BigInteger; + +import com.clickhouse.client.ClickHouseDataType; + +import sqlancer.IgnoreMeException; +import sqlancer.clickhouse.ast.ClickHouseConstant; + +public class ClickHouseBooleanConstant extends ClickHouseConstant { + + private final boolean value; + + public ClickHouseBooleanConstant(boolean value) { + this.value = value; + } + + @Override + public String toString() { + return value ? "true" : "false"; + } + + @Override + public boolean isNull() { + return false; + } + + @Override + public boolean asBooleanNotNull() { + return value; + } + + @Override + public ClickHouseDataType getDataType() { + return ClickHouseDataType.Bool; + } + + @Override + public boolean compareInternal(Object val) { + if (val instanceof Boolean) { + return value == ((Boolean) val).booleanValue(); + } else { + return value == ((int) val != 0); + } + } + + @Override + public ClickHouseConstant applyLess(ClickHouseConstant right) { + if (this.getDataType() == right.getDataType()) { + return this.asInt() < right.asInt() ? ClickHouseCreateConstant.createTrue() + : ClickHouseCreateConstant.createFalse(); + } + throw new IgnoreMeException(); + } + + @Override + public long asInt() { + return value ? 1 : 0; + } + + @Override + public Object getValue() { + return value; + } + + @Override + public ClickHouseConstant cast(ClickHouseDataType type) { + switch (type) { + case String: + return ClickHouseCreateConstant.createStringConstant(this.toString()); + case UInt8: + return ClickHouseCreateConstant.createUInt8Constant(asInt()); + case Int8: + return ClickHouseCreateConstant.createInt8Constant(asInt()); + case UInt16: + return ClickHouseCreateConstant.createUInt16Constant(asInt()); + case Int16: + return ClickHouseCreateConstant.createInt16Constant(asInt()); + case UInt32: + return ClickHouseCreateConstant.createUInt32Constant(asInt()); + case Int32: + return ClickHouseCreateConstant.createInt32Constant(asInt()); + case UInt64: + return ClickHouseCreateConstant.createUInt64Constant(BigInteger.valueOf(asInt())); + case Int64: + return ClickHouseCreateConstant.createInt64Constant(BigInteger.valueOf(asInt())); + case UInt128: + return ClickHouseCreateConstant.createUInt128Constant(BigInteger.valueOf(asInt())); + case Int128: + return ClickHouseCreateConstant.createInt128Constant(BigInteger.valueOf(asInt())); + case UInt256: + return ClickHouseCreateConstant.createUInt256Constant(BigInteger.valueOf(asInt())); + case Int256: + return ClickHouseCreateConstant.createInt256Constant(BigInteger.valueOf(asInt())); + case Float32: + return ClickHouseCreateConstant.createFloat32Constant((float) asInt()); + case Float64: + return ClickHouseCreateConstant.createFloat64Constant((double) asInt()); + case Nothing: + return ClickHouseCreateConstant.createNullConstant(); + case Bool: + return ClickHouseCreateConstant.createBooleanConstant(asInt() != 0); + case IntervalYear: + case IntervalQuarter: + case IntervalMonth: + case IntervalWeek: + case IntervalDay: + case IntervalHour: + case IntervalMinute: + case IntervalSecond: + case Date: + case Date32: + case DateTime: + case DateTime32: + case DateTime64: + case Decimal: + case Decimal32: + case Decimal64: + case Decimal128: + case Decimal256: + case UUID: + case Enum: + case Enum8: + case Enum16: + case IPv4: + case IPv6: + case FixedString: + case AggregateFunction: + case SimpleAggregateFunction: + case Array: + case Map: + case Nested: + case Tuple: + case Point: + case Polygon: + case MultiPolygon: + case Ring: + default: + throw new AssertionError(type); + } + } +} diff --git a/src/sqlancer/clickhouse/ast/constant/ClickHouseCreateConstant.java b/src/sqlancer/clickhouse/ast/constant/ClickHouseCreateConstant.java new file mode 100644 index 000000000..dc87619a8 --- /dev/null +++ b/src/sqlancer/clickhouse/ast/constant/ClickHouseCreateConstant.java @@ -0,0 +1,162 @@ +package sqlancer.clickhouse.ast.constant; + +import java.math.BigInteger; + +import com.clickhouse.client.ClickHouseDataType; + +import sqlancer.clickhouse.ast.ClickHouseConstant; +import sqlancer.clickhouse.ast.ClickHouseExpression; + +public abstract class ClickHouseCreateConstant extends ClickHouseExpression { + + public static ClickHouseConstant createStringConstant(String text) { + return new ClickHouseStringConstant(text); + } + + public static ClickHouseConstant createFloat64Constant(double val) { + return new ClickHouseFloat64Constant(val); + } + + public static ClickHouseConstant createFloat32Constant(float val) { + return new ClickHouseFloat32Constant(val); + } + + public static ClickHouseConstant createInt256Constant(BigInteger val) { + return new ClickHouseInt256Constant(val); + } + + public static ClickHouseConstant createUInt256Constant(BigInteger val) { + return new ClickHouseUInt256Constant(val); + } + + public static ClickHouseConstant createInt128Constant(BigInteger val) { + return new ClickHouseInt128Constant(val); + } + + public static ClickHouseConstant createUInt128Constant(BigInteger val) { + return new ClickHouseUInt128Constant(val); + } + + public static ClickHouseConstant createInt64Constant(BigInteger val) { + return new ClickHouseInt64Constant(val); + } + + public static ClickHouseConstant createUInt64Constant(BigInteger val) { + return new ClickHouseUInt64Constant(val); + } + + public static ClickHouseConstant createInt32Constant(long val) { + return new ClickHouseInt32Constant(val); + } + + public static ClickHouseConstant createUInt32Constant(long val) { + return new ClickHouseUInt32Constant(val); + } + + public static ClickHouseConstant createUInt16Constant(long val) { + return new ClickHouseUInt16Constant(val); + } + + public static ClickHouseConstant createInt16Constant(long val) { + return new ClickHouseInt16Constant(val); + } + + public static ClickHouseConstant createUInt8Constant(long val) { + return new ClickHouseUInt8Constant((int) val); + } + + public static ClickHouseConstant createInt8Constant(long val) { + return new ClickHouseInt8Constant((int) val); + } + + public static ClickHouseConstant createBooleanConstant(Boolean b) { + return new ClickHouseBooleanConstant(b); + } + + public static ClickHouseConstant createNullConstant() { + return new ClickHouseNullConstant(); + } + + public static ClickHouseConstant createTrue() { + return new ClickHouseBooleanConstant(true); + } + + public static ClickHouseConstant createFalse() { + return new ClickHouseBooleanConstant(false); + } + + public static ClickHouseConstant createBoolean(boolean val) { + return val ? createTrue() : createFalse(); + } + + public static ClickHouseConstant createIntConstant(ClickHouseDataType type, long val) { + switch (type) { + case UInt256: + return ClickHouseCreateConstant.createUInt256Constant(BigInteger.valueOf(val)); + case UInt128: + return ClickHouseCreateConstant.createUInt128Constant(BigInteger.valueOf(val)); + case UInt64: + return ClickHouseCreateConstant.createUInt64Constant(BigInteger.valueOf(val)); + case UInt32: + return ClickHouseCreateConstant.createUInt32Constant(val); + case UInt16: + return ClickHouseCreateConstant.createUInt16Constant(val); + case UInt8: + return ClickHouseCreateConstant.createUInt8Constant(val); + case Int256: + return ClickHouseCreateConstant.createInt256Constant(BigInteger.valueOf(val)); + case Int128: + return ClickHouseCreateConstant.createInt128Constant(BigInteger.valueOf(val)); + case Int64: + return ClickHouseCreateConstant.createInt64Constant(BigInteger.valueOf(val)); + case Int32: + return ClickHouseCreateConstant.createInt32Constant(val); + case Int16: + return ClickHouseCreateConstant.createInt16Constant(val); + case Int8: + return ClickHouseCreateConstant.createInt8Constant(val); + case Nothing: + return ClickHouseCreateConstant.createNullConstant(); + case Bool: + case String: + case IntervalYear: + case IntervalQuarter: + case IntervalMonth: + case IntervalWeek: + case IntervalDay: + case IntervalHour: + case IntervalMinute: + case IntervalSecond: + case Date: + case Date32: + case DateTime: + case DateTime32: + case DateTime64: + case Decimal: + case Decimal32: + case Decimal64: + case Decimal128: + case Decimal256: + case UUID: + case Enum: + case Enum8: + case Enum16: + case IPv4: + case IPv6: + case FixedString: + case AggregateFunction: + case SimpleAggregateFunction: + case Array: + case Map: + case Nested: + case Tuple: + case Point: + case Polygon: + case MultiPolygon: + case Ring: + default: + throw new AssertionError(type); + } + } + +} diff --git a/src/sqlancer/clickhouse/ast/constant/ClickHouseFloat32Constant.java b/src/sqlancer/clickhouse/ast/constant/ClickHouseFloat32Constant.java new file mode 100644 index 000000000..1385790cb --- /dev/null +++ b/src/sqlancer/clickhouse/ast/constant/ClickHouseFloat32Constant.java @@ -0,0 +1,136 @@ +package sqlancer.clickhouse.ast.constant; + +import java.math.BigInteger; + +import com.clickhouse.client.ClickHouseDataType; + +import sqlancer.clickhouse.ast.ClickHouseConstant; + +public class ClickHouseFloat32Constant extends ClickHouseConstant { + + private final float value; + + public ClickHouseFloat32Constant(float value) { + this.value = value; + } + + @Override + public Object getValue() { + return value; + } + + @Override + public boolean isNull() { + return false; + } + + @Override + public String toString() { + if (value == Double.POSITIVE_INFINITY) { + return "'+Inf'"; + } else if (value == Double.NEGATIVE_INFINITY) { + return "'-Inf'"; + } + return String.valueOf(value); + } + + @Override + public boolean compareInternal(Object val) { + return Float.compare(value, (float) val) == 0; + } + + @Override + public ClickHouseConstant applyLess(ClickHouseConstant right) { + if (this.getDataType() == right.getDataType()) { + return this.asDouble() < right.asDouble() ? ClickHouseCreateConstant.createTrue() + : ClickHouseCreateConstant.createFalse(); + } + ClickHouseConstant converted = right.cast(ClickHouseDataType.Float32); + return this.asDouble() < converted.asDouble() ? ClickHouseCreateConstant.createTrue() + : ClickHouseCreateConstant.createFalse(); + } + + @Override + public boolean asBooleanNotNull() { + return Float.compare(value, (float) 0) == 0; + } + + @Override + public ClickHouseDataType getDataType() { + return ClickHouseDataType.Float32; + } + + @Override + public double asDouble() { + return value; + } + + @Override + public ClickHouseConstant cast(ClickHouseDataType type) { + switch (type) { + case String: + return ClickHouseCreateConstant.createStringConstant(this.toString()); + case UInt8: + return ClickHouseCreateConstant.createUInt8Constant((long) value); + case Int8: + return ClickHouseCreateConstant.createInt8Constant((long) value); + case UInt16: + return ClickHouseCreateConstant.createUInt16Constant((long) value); + case Int16: + return ClickHouseCreateConstant.createInt16Constant((long) value); + case UInt32: + return ClickHouseCreateConstant.createUInt32Constant((long) value); + case Int32: + return ClickHouseCreateConstant.createInt32Constant((long) value); + case UInt64: + return ClickHouseCreateConstant.createUInt64Constant(BigInteger.valueOf((long) value)); + case Int64: + return ClickHouseCreateConstant.createInt64Constant(BigInteger.valueOf((long) value)); + case Float32: + return ClickHouseCreateConstant.createFloat32Constant(value); + case Float64: + return ClickHouseCreateConstant.createFloat64Constant(value); + case Nothing: + return ClickHouseCreateConstant.createNullConstant(); + case Bool: + return ClickHouseCreateConstant.createBooleanConstant(value != 0); + case IntervalYear: + case IntervalQuarter: + case IntervalMonth: + case IntervalWeek: + case IntervalDay: + case IntervalHour: + case IntervalMinute: + case IntervalSecond: + case Date: + case Date32: + case DateTime: + case DateTime32: + case DateTime64: + case Decimal: + case Decimal32: + case Decimal64: + case Decimal128: + case Decimal256: + case UUID: + case Enum: + case Enum8: + case Enum16: + case IPv4: + case IPv6: + case FixedString: + case AggregateFunction: + case SimpleAggregateFunction: + case Array: + case Map: + case Nested: + case Tuple: + case Point: + case Polygon: + case MultiPolygon: + case Ring: + default: + throw new AssertionError(type); + } + } +} diff --git a/src/sqlancer/clickhouse/ast/constant/ClickHouseFloat64Constant.java b/src/sqlancer/clickhouse/ast/constant/ClickHouseFloat64Constant.java new file mode 100644 index 000000000..9146faa39 --- /dev/null +++ b/src/sqlancer/clickhouse/ast/constant/ClickHouseFloat64Constant.java @@ -0,0 +1,136 @@ +package sqlancer.clickhouse.ast.constant; + +import java.math.BigInteger; + +import com.clickhouse.client.ClickHouseDataType; + +import sqlancer.clickhouse.ast.ClickHouseConstant; + +public class ClickHouseFloat64Constant extends ClickHouseConstant { + + private final double value; + + public ClickHouseFloat64Constant(double value) { + this.value = value; + } + + @Override + public Object getValue() { + return value; + } + + @Override + public boolean isNull() { + return false; + } + + @Override + public String toString() { + if (value == Double.POSITIVE_INFINITY) { + return "'+Inf'"; + } else if (value == Double.NEGATIVE_INFINITY) { + return "'-Inf'"; + } + return String.valueOf(value); + } + + @Override + public boolean compareInternal(Object val) { + return Double.compare(value, (double) val) == 0; + } + + @Override + public ClickHouseConstant applyLess(ClickHouseConstant right) { + if (this.getDataType() == right.getDataType()) { + return this.asDouble() < right.asDouble() ? ClickHouseCreateConstant.createTrue() + : ClickHouseCreateConstant.createFalse(); + } + ClickHouseConstant converted = right.cast(ClickHouseDataType.Float64); + return this.asDouble() < converted.asDouble() ? ClickHouseCreateConstant.createTrue() + : ClickHouseCreateConstant.createFalse(); + } + + @Override + public boolean asBooleanNotNull() { + return Double.compare(value, 0.0) == 0; + } + + @Override + public ClickHouseDataType getDataType() { + return ClickHouseDataType.Float64; + } + + @Override + public double asDouble() { + return value; + } + + @Override + public ClickHouseConstant cast(ClickHouseDataType type) { + switch (type) { + case String: + return ClickHouseCreateConstant.createStringConstant(this.toString()); + case UInt8: + return ClickHouseCreateConstant.createUInt8Constant((long) value); + case Int8: + return ClickHouseCreateConstant.createInt8Constant((long) value); + case UInt16: + return ClickHouseCreateConstant.createUInt16Constant((long) value); + case Int16: + return ClickHouseCreateConstant.createInt16Constant((long) value); + case UInt32: + return ClickHouseCreateConstant.createUInt32Constant((long) value); + case Int32: + return ClickHouseCreateConstant.createInt32Constant((long) value); + case UInt64: + return ClickHouseCreateConstant.createUInt64Constant(BigInteger.valueOf((long) value)); + case Int64: + return ClickHouseCreateConstant.createInt64Constant(BigInteger.valueOf((long) value)); + case Float32: + return ClickHouseCreateConstant.createFloat32Constant((float) value); + case Float64: + return ClickHouseCreateConstant.createFloat64Constant(value); + case Nothing: + return ClickHouseCreateConstant.createNullConstant(); + case Bool: + return ClickHouseCreateConstant.createBooleanConstant(value != 0); + case IntervalYear: + case IntervalQuarter: + case IntervalMonth: + case IntervalWeek: + case IntervalDay: + case IntervalHour: + case IntervalMinute: + case IntervalSecond: + case Date: + case Date32: + case DateTime: + case DateTime32: + case DateTime64: + case Decimal: + case Decimal32: + case Decimal64: + case Decimal128: + case Decimal256: + case UUID: + case Enum: + case Enum8: + case Enum16: + case IPv4: + case IPv6: + case FixedString: + case AggregateFunction: + case SimpleAggregateFunction: + case Array: + case Map: + case Nested: + case Tuple: + case Point: + case Polygon: + case MultiPolygon: + case Ring: + default: + throw new AssertionError(type); + } + } +} diff --git a/src/sqlancer/clickhouse/ast/constant/ClickHouseInt128Constant.java b/src/sqlancer/clickhouse/ast/constant/ClickHouseInt128Constant.java new file mode 100644 index 000000000..c388cb638 --- /dev/null +++ b/src/sqlancer/clickhouse/ast/constant/ClickHouseInt128Constant.java @@ -0,0 +1,139 @@ +package sqlancer.clickhouse.ast.constant; + +import java.math.BigInteger; + +import com.clickhouse.client.ClickHouseDataType; + +import sqlancer.IgnoreMeException; +import sqlancer.clickhouse.ast.ClickHouseConstant; + +public class ClickHouseInt128Constant extends ClickHouseConstant { + + private final BigInteger value; + + public ClickHouseInt128Constant(BigInteger value) { + this.value = value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + + @Override + public boolean isNull() { + return false; + } + + @Override + public boolean asBooleanNotNull() { + return value != BigInteger.ZERO; + } + + @Override + public ClickHouseDataType getDataType() { + return ClickHouseDataType.Int128; + } + + @Override + public boolean compareInternal(Object val) { + return value.compareTo((BigInteger) val) == 0; + } + + @Override + public ClickHouseConstant applyLess(ClickHouseConstant right) { + if (this.getDataType() == right.getDataType()) { + return this.asInt() < right.asInt() ? ClickHouseCreateConstant.createTrue() + : ClickHouseCreateConstant.createFalse(); + } + throw new IgnoreMeException(); + } + + @Override + public long asInt() { + return value.longValueExact(); + } + + @Override + public Object getValue() { + return value; + } + + @Override + public ClickHouseConstant cast(ClickHouseDataType type) { + long val = value.longValueExact(); + switch (type) { + case String: + return ClickHouseCreateConstant.createStringConstant(this.toString()); + case UInt8: + return ClickHouseCreateConstant.createUInt8Constant(val); + case Int8: + return ClickHouseCreateConstant.createInt8Constant(val); + case UInt16: + return ClickHouseCreateConstant.createUInt16Constant(val); + case Int16: + return ClickHouseCreateConstant.createInt16Constant(val); + case UInt32: + return ClickHouseCreateConstant.createUInt32Constant(val); + case Int32: + return ClickHouseCreateConstant.createInt32Constant(val); + case UInt64: + return ClickHouseCreateConstant.createUInt64Constant(value); + case Int64: + return ClickHouseCreateConstant.createInt64Constant(value); + case UInt128: + return ClickHouseCreateConstant.createUInt128Constant(value); + case Int128: + return ClickHouseCreateConstant.createInt128Constant(value); + case UInt256: + return ClickHouseCreateConstant.createUInt256Constant(value); + case Int256: + return ClickHouseCreateConstant.createInt256Constant(value); + case Float32: + return ClickHouseCreateConstant.createFloat32Constant(value.floatValue()); + case Float64: + return ClickHouseCreateConstant.createFloat64Constant(value.doubleValue()); + case Nothing: + return ClickHouseCreateConstant.createNullConstant(); + case Bool: + return ClickHouseCreateConstant.createBooleanConstant(val != 0); + case IntervalYear: + case IntervalQuarter: + case IntervalMonth: + case IntervalWeek: + case IntervalDay: + case IntervalHour: + case IntervalMinute: + case IntervalSecond: + case Date: + case Date32: + case DateTime: + case DateTime32: + case DateTime64: + case Decimal: + case Decimal32: + case Decimal64: + case Decimal128: + case Decimal256: + case UUID: + case Enum: + case Enum8: + case Enum16: + case IPv4: + case IPv6: + case FixedString: + case AggregateFunction: + case SimpleAggregateFunction: + case Array: + case Map: + case Nested: + case Tuple: + case Point: + case Polygon: + case MultiPolygon: + case Ring: + default: + throw new AssertionError(type); + } + } +} diff --git a/src/sqlancer/clickhouse/ast/constant/ClickHouseInt16Constant.java b/src/sqlancer/clickhouse/ast/constant/ClickHouseInt16Constant.java new file mode 100644 index 000000000..1fde4dfb1 --- /dev/null +++ b/src/sqlancer/clickhouse/ast/constant/ClickHouseInt16Constant.java @@ -0,0 +1,130 @@ +package sqlancer.clickhouse.ast.constant; + +import java.math.BigInteger; + +import com.clickhouse.client.ClickHouseDataType; + +import sqlancer.IgnoreMeException; +import sqlancer.clickhouse.ast.ClickHouseConstant; + +public class ClickHouseInt16Constant extends ClickHouseConstant { + + private final long value; + + public ClickHouseInt16Constant(long value) { + this.value = value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + + @Override + public boolean isNull() { + return false; + } + + @Override + public boolean asBooleanNotNull() { + return value != 0; + } + + @Override + public ClickHouseDataType getDataType() { + return ClickHouseDataType.Int16; + } + + @Override + public boolean compareInternal(Object val) { + return value == (long) val; + } + + @Override + public ClickHouseConstant applyLess(ClickHouseConstant right) { + if (this.getDataType() == right.getDataType()) { + return this.asInt() < right.asInt() ? ClickHouseCreateConstant.createTrue() + : ClickHouseCreateConstant.createFalse(); + } + throw new IgnoreMeException(); + } + + @Override + public long asInt() { + return value; + } + + @Override + public Object getValue() { + return value; + } + + @Override + public ClickHouseConstant cast(ClickHouseDataType type) { + switch (type) { + case String: + return ClickHouseCreateConstant.createStringConstant(this.toString()); + case UInt8: + return ClickHouseCreateConstant.createUInt8Constant(value); + case Int8: + return ClickHouseCreateConstant.createInt8Constant(value); + case UInt16: + return ClickHouseCreateConstant.createUInt16Constant(value); + case Int16: + return ClickHouseCreateConstant.createInt16Constant(value); + case UInt32: + return ClickHouseCreateConstant.createUInt32Constant(value); + case Int32: + return ClickHouseCreateConstant.createInt32Constant(value); + case UInt64: + return ClickHouseCreateConstant.createUInt64Constant(BigInteger.valueOf(value)); + case Int64: + return ClickHouseCreateConstant.createInt64Constant(BigInteger.valueOf(value)); + case Float32: + return ClickHouseCreateConstant.createFloat32Constant((float) value); + case Float64: + return ClickHouseCreateConstant.createFloat64Constant(value); + case Nothing: + return ClickHouseCreateConstant.createNullConstant(); + case Bool: + return ClickHouseCreateConstant.createBooleanConstant(value != 0); + case IntervalYear: + case IntervalQuarter: + case IntervalMonth: + case IntervalWeek: + case IntervalDay: + case IntervalHour: + case IntervalMinute: + case IntervalSecond: + case Date: + case Date32: + case DateTime: + case DateTime32: + case DateTime64: + case Decimal: + case Decimal32: + case Decimal64: + case Decimal128: + case Decimal256: + case UUID: + case Enum: + case Enum8: + case Enum16: + case IPv4: + case IPv6: + case FixedString: + case AggregateFunction: + case SimpleAggregateFunction: + case Array: + case Map: + case Nested: + case Tuple: + case Point: + case Polygon: + case MultiPolygon: + case Ring: + default: + throw new AssertionError(type); + } + } +} diff --git a/src/sqlancer/clickhouse/ast/constant/ClickHouseInt256Constant.java b/src/sqlancer/clickhouse/ast/constant/ClickHouseInt256Constant.java new file mode 100644 index 000000000..f40ea5297 --- /dev/null +++ b/src/sqlancer/clickhouse/ast/constant/ClickHouseInt256Constant.java @@ -0,0 +1,139 @@ +package sqlancer.clickhouse.ast.constant; + +import java.math.BigInteger; + +import com.clickhouse.client.ClickHouseDataType; + +import sqlancer.IgnoreMeException; +import sqlancer.clickhouse.ast.ClickHouseConstant; + +public class ClickHouseInt256Constant extends ClickHouseConstant { + + private final BigInteger value; + + public ClickHouseInt256Constant(BigInteger value) { + this.value = value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + + @Override + public boolean isNull() { + return false; + } + + @Override + public boolean asBooleanNotNull() { + return value != BigInteger.ZERO; + } + + @Override + public ClickHouseDataType getDataType() { + return ClickHouseDataType.Int256; + } + + @Override + public boolean compareInternal(Object val) { + return value.compareTo((BigInteger) val) == 0; + } + + @Override + public ClickHouseConstant applyLess(ClickHouseConstant right) { + if (this.getDataType() == right.getDataType()) { + return this.asInt() < right.asInt() ? ClickHouseCreateConstant.createTrue() + : ClickHouseCreateConstant.createFalse(); + } + throw new IgnoreMeException(); + } + + @Override + public long asInt() { + return value.longValueExact(); + } + + @Override + public Object getValue() { + return value; + } + + @Override + public ClickHouseConstant cast(ClickHouseDataType type) { + long val = value.longValueExact(); + switch (type) { + case String: + return ClickHouseCreateConstant.createStringConstant(this.toString()); + case UInt8: + return ClickHouseCreateConstant.createUInt8Constant(val); + case Int8: + return ClickHouseCreateConstant.createInt8Constant(val); + case UInt16: + return ClickHouseCreateConstant.createUInt16Constant(val); + case Int16: + return ClickHouseCreateConstant.createInt16Constant(val); + case UInt32: + return ClickHouseCreateConstant.createUInt32Constant(val); + case Int32: + return ClickHouseCreateConstant.createInt32Constant(val); + case UInt64: + return ClickHouseCreateConstant.createUInt64Constant(value); + case Int64: + return ClickHouseCreateConstant.createInt64Constant(value); + case UInt128: + return ClickHouseCreateConstant.createUInt128Constant(value); + case Int128: + return ClickHouseCreateConstant.createInt128Constant(value); + case UInt256: + return ClickHouseCreateConstant.createUInt256Constant(value); + case Int256: + return ClickHouseCreateConstant.createInt256Constant(value); + case Float32: + return ClickHouseCreateConstant.createFloat32Constant(value.floatValue()); + case Float64: + return ClickHouseCreateConstant.createFloat64Constant(value.doubleValue()); + case Nothing: + return ClickHouseCreateConstant.createNullConstant(); + case Bool: + return ClickHouseCreateConstant.createBooleanConstant(val != 0); + case IntervalYear: + case IntervalQuarter: + case IntervalMonth: + case IntervalWeek: + case IntervalDay: + case IntervalHour: + case IntervalMinute: + case IntervalSecond: + case Date: + case Date32: + case DateTime: + case DateTime32: + case DateTime64: + case Decimal: + case Decimal32: + case Decimal64: + case Decimal128: + case Decimal256: + case UUID: + case Enum: + case Enum8: + case Enum16: + case IPv4: + case IPv6: + case FixedString: + case AggregateFunction: + case SimpleAggregateFunction: + case Array: + case Map: + case Nested: + case Tuple: + case Point: + case Polygon: + case MultiPolygon: + case Ring: + default: + throw new AssertionError(type); + } + } +} diff --git a/src/sqlancer/clickhouse/ast/constant/ClickHouseInt32Constant.java b/src/sqlancer/clickhouse/ast/constant/ClickHouseInt32Constant.java new file mode 100644 index 000000000..408df18fc --- /dev/null +++ b/src/sqlancer/clickhouse/ast/constant/ClickHouseInt32Constant.java @@ -0,0 +1,130 @@ +package sqlancer.clickhouse.ast.constant; + +import java.math.BigInteger; + +import com.clickhouse.client.ClickHouseDataType; + +import sqlancer.IgnoreMeException; +import sqlancer.clickhouse.ast.ClickHouseConstant; + +public class ClickHouseInt32Constant extends ClickHouseConstant { + + private final long value; + + public ClickHouseInt32Constant(long value) { + this.value = value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + + @Override + public boolean isNull() { + return false; + } + + @Override + public boolean asBooleanNotNull() { + return value != 0; + } + + @Override + public ClickHouseDataType getDataType() { + return ClickHouseDataType.Int32; + } + + @Override + public boolean compareInternal(Object val) { + return value == (long) val; + } + + @Override + public ClickHouseConstant applyLess(ClickHouseConstant right) { + if (this.getDataType() == right.getDataType()) { + return this.asInt() < right.asInt() ? ClickHouseCreateConstant.createTrue() + : ClickHouseCreateConstant.createFalse(); + } + throw new IgnoreMeException(); + } + + @Override + public long asInt() { + return value; + } + + @Override + public Object getValue() { + return value; + } + + @Override + public ClickHouseConstant cast(ClickHouseDataType type) { + switch (type) { + case String: + return ClickHouseCreateConstant.createStringConstant(this.toString()); + case UInt8: + return ClickHouseCreateConstant.createUInt8Constant(value); + case Int8: + return ClickHouseCreateConstant.createInt8Constant(value); + case UInt16: + return ClickHouseCreateConstant.createUInt16Constant(value); + case Int16: + return ClickHouseCreateConstant.createInt16Constant(value); + case UInt32: + return ClickHouseCreateConstant.createUInt32Constant(value); + case Int32: + return ClickHouseCreateConstant.createInt32Constant(value); + case UInt64: + return ClickHouseCreateConstant.createUInt64Constant(BigInteger.valueOf(value)); + case Int64: + return ClickHouseCreateConstant.createInt64Constant(BigInteger.valueOf(value)); + case Float32: + return ClickHouseCreateConstant.createFloat32Constant((float) value); + case Float64: + return ClickHouseCreateConstant.createFloat64Constant(value); + case Nothing: + return ClickHouseCreateConstant.createNullConstant(); + case Bool: + return ClickHouseCreateConstant.createBooleanConstant(value != 0); + case IntervalYear: + case IntervalQuarter: + case IntervalMonth: + case IntervalWeek: + case IntervalDay: + case IntervalHour: + case IntervalMinute: + case IntervalSecond: + case Date: + case Date32: + case DateTime: + case DateTime32: + case DateTime64: + case Decimal: + case Decimal32: + case Decimal64: + case Decimal128: + case Decimal256: + case UUID: + case Enum: + case Enum8: + case Enum16: + case IPv4: + case IPv6: + case FixedString: + case AggregateFunction: + case SimpleAggregateFunction: + case Array: + case Map: + case Nested: + case Tuple: + case Point: + case Polygon: + case MultiPolygon: + case Ring: + default: + throw new AssertionError(type); + } + } +} diff --git a/src/sqlancer/clickhouse/ast/constant/ClickHouseInt64Constant.java b/src/sqlancer/clickhouse/ast/constant/ClickHouseInt64Constant.java new file mode 100644 index 000000000..cb5c4b5f8 --- /dev/null +++ b/src/sqlancer/clickhouse/ast/constant/ClickHouseInt64Constant.java @@ -0,0 +1,131 @@ +package sqlancer.clickhouse.ast.constant; + +import java.math.BigInteger; + +import com.clickhouse.client.ClickHouseDataType; + +import sqlancer.IgnoreMeException; +import sqlancer.clickhouse.ast.ClickHouseConstant; + +public class ClickHouseInt64Constant extends ClickHouseConstant { + + private final BigInteger value; + + public ClickHouseInt64Constant(BigInteger value) { + this.value = value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + + @Override + public boolean isNull() { + return false; + } + + @Override + public boolean asBooleanNotNull() { + return value != BigInteger.ZERO; + } + + @Override + public ClickHouseDataType getDataType() { + return ClickHouseDataType.Int64; + } + + @Override + public boolean compareInternal(Object val) { + return value.compareTo((BigInteger) val) == 0; + } + + @Override + public ClickHouseConstant applyLess(ClickHouseConstant right) { + if (this.getDataType() == right.getDataType()) { + return this.asInt() < right.asInt() ? ClickHouseCreateConstant.createTrue() + : ClickHouseCreateConstant.createFalse(); + } + throw new IgnoreMeException(); + } + + @Override + public long asInt() { + return value.longValueExact(); + } + + @Override + public Object getValue() { + return value; + } + + @Override + public ClickHouseConstant cast(ClickHouseDataType type) { + long val = value.longValueExact(); + switch (type) { + case String: + return ClickHouseCreateConstant.createStringConstant(this.toString()); + case UInt8: + return ClickHouseCreateConstant.createUInt8Constant(val); + case Int8: + return ClickHouseCreateConstant.createInt8Constant(val); + case UInt16: + return ClickHouseCreateConstant.createUInt16Constant(val); + case Int16: + return ClickHouseCreateConstant.createInt16Constant(val); + case UInt32: + return ClickHouseCreateConstant.createUInt32Constant(val); + case Int32: + return ClickHouseCreateConstant.createInt32Constant(val); + case UInt64: + return ClickHouseCreateConstant.createUInt64Constant(value); + case Int64: + return ClickHouseCreateConstant.createInt64Constant(value); + case Float32: + return ClickHouseCreateConstant.createFloat32Constant(value.floatValue()); + case Float64: + return ClickHouseCreateConstant.createFloat64Constant(value.doubleValue()); + case Nothing: + return ClickHouseCreateConstant.createNullConstant(); + case Bool: + return ClickHouseCreateConstant.createBooleanConstant(val != 0); + case IntervalYear: + case IntervalQuarter: + case IntervalMonth: + case IntervalWeek: + case IntervalDay: + case IntervalHour: + case IntervalMinute: + case IntervalSecond: + case Date: + case Date32: + case DateTime: + case DateTime32: + case DateTime64: + case Decimal: + case Decimal32: + case Decimal64: + case Decimal128: + case Decimal256: + case UUID: + case Enum: + case Enum8: + case Enum16: + case IPv4: + case IPv6: + case FixedString: + case AggregateFunction: + case SimpleAggregateFunction: + case Array: + case Map: + case Nested: + case Tuple: + case Point: + case Polygon: + case MultiPolygon: + case Ring: + default: + throw new AssertionError(type); + } + } +} diff --git a/src/sqlancer/clickhouse/ast/constant/ClickHouseInt8Constant.java b/src/sqlancer/clickhouse/ast/constant/ClickHouseInt8Constant.java new file mode 100644 index 000000000..38ad4878c --- /dev/null +++ b/src/sqlancer/clickhouse/ast/constant/ClickHouseInt8Constant.java @@ -0,0 +1,130 @@ +package sqlancer.clickhouse.ast.constant; + +import java.math.BigInteger; + +import com.clickhouse.client.ClickHouseDataType; + +import sqlancer.IgnoreMeException; +import sqlancer.clickhouse.ast.ClickHouseConstant; + +public class ClickHouseInt8Constant extends ClickHouseConstant { + + private final int value; + + public ClickHouseInt8Constant(int value) { + this.value = value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + + @Override + public boolean isNull() { + return false; + } + + @Override + public boolean asBooleanNotNull() { + return value != 0; + } + + @Override + public ClickHouseDataType getDataType() { + return ClickHouseDataType.Int8; + } + + @Override + public boolean compareInternal(Object val) { + return value == (int) val; + } + + @Override + public ClickHouseConstant applyLess(ClickHouseConstant right) { + if (this.getDataType() == right.getDataType()) { + return this.asInt() < right.asInt() ? ClickHouseCreateConstant.createTrue() + : ClickHouseCreateConstant.createFalse(); + } + throw new IgnoreMeException(); + } + + @Override + public long asInt() { + return value; + } + + @Override + public Object getValue() { + return value; + } + + @Override + public ClickHouseConstant cast(ClickHouseDataType type) { + switch (type) { + case String: + return ClickHouseCreateConstant.createStringConstant(this.toString()); + case UInt8: + return ClickHouseCreateConstant.createUInt8Constant(value); + case Int8: + return ClickHouseCreateConstant.createInt8Constant(value); + case UInt16: + return ClickHouseCreateConstant.createUInt16Constant(value); + case Int16: + return ClickHouseCreateConstant.createInt16Constant(value); + case UInt32: + return ClickHouseCreateConstant.createUInt32Constant(value); + case Int32: + return ClickHouseCreateConstant.createInt32Constant(value); + case UInt64: + return ClickHouseCreateConstant.createUInt64Constant(BigInteger.valueOf(value)); + case Int64: + return ClickHouseCreateConstant.createInt64Constant(BigInteger.valueOf(value)); + case Float32: + return ClickHouseCreateConstant.createFloat32Constant((float) value); + case Float64: + return ClickHouseCreateConstant.createFloat64Constant(value); + case Nothing: + return ClickHouseCreateConstant.createNullConstant(); + case Bool: + return ClickHouseCreateConstant.createBooleanConstant(value != 0); + case IntervalYear: + case IntervalQuarter: + case IntervalMonth: + case IntervalWeek: + case IntervalDay: + case IntervalHour: + case IntervalMinute: + case IntervalSecond: + case Date: + case Date32: + case DateTime: + case DateTime32: + case DateTime64: + case Decimal: + case Decimal32: + case Decimal64: + case Decimal128: + case Decimal256: + case UUID: + case Enum: + case Enum8: + case Enum16: + case IPv4: + case IPv6: + case FixedString: + case AggregateFunction: + case SimpleAggregateFunction: + case Array: + case Map: + case Nested: + case Tuple: + case Point: + case Polygon: + case MultiPolygon: + case Ring: + default: + throw new AssertionError(type); + } + } +} diff --git a/src/sqlancer/clickhouse/ast/constant/ClickHouseNullConstant.java b/src/sqlancer/clickhouse/ast/constant/ClickHouseNullConstant.java new file mode 100644 index 000000000..08370d560 --- /dev/null +++ b/src/sqlancer/clickhouse/ast/constant/ClickHouseNullConstant.java @@ -0,0 +1,53 @@ +package sqlancer.clickhouse.ast.constant; + +import com.clickhouse.client.ClickHouseDataType; + +import sqlancer.clickhouse.ast.ClickHouseConstant; + +public class ClickHouseNullConstant extends ClickHouseConstant { + + @Override + public String toString() { + return "NULL"; + } + + @Override + public boolean isNull() { + return true; + } + + @Override + public boolean asBooleanNotNull() { + throw new AssertionError(); + } + + @Override + public ClickHouseDataType getDataType() { + return ClickHouseDataType.Nothing; + } + + @Override + public boolean compareInternal(Object value) { + return false; + } + + @Override + public ClickHouseConstant applyEquals(ClickHouseConstant right) { + return ClickHouseCreateConstant.createNullConstant(); + } + + @Override + public ClickHouseConstant applyLess(ClickHouseConstant right) { + return ClickHouseCreateConstant.createNullConstant(); + } + + @Override + public Object getValue() { + return null; + } + + @Override + public ClickHouseConstant cast(ClickHouseDataType type) { + return null; + } +} diff --git a/src/sqlancer/clickhouse/ast/constant/ClickHouseStringConstant.java b/src/sqlancer/clickhouse/ast/constant/ClickHouseStringConstant.java new file mode 100644 index 000000000..fa8ba11f5 --- /dev/null +++ b/src/sqlancer/clickhouse/ast/constant/ClickHouseStringConstant.java @@ -0,0 +1,125 @@ +package sqlancer.clickhouse.ast.constant; + +import java.math.BigInteger; + +import com.clickhouse.client.ClickHouseDataType; + +import sqlancer.IgnoreMeException; +import sqlancer.clickhouse.ast.ClickHouseConstant; + +public class ClickHouseStringConstant extends ClickHouseConstant { + + private final String value; + + public ClickHouseStringConstant(String value) { + this.value = value; + } + + @Override + public boolean isNull() { + return false; + } + + @Override + public Object getValue() { + return value; + } + + @Override + public String toString() { + return "'" + value.replace("\\", "\\\\").replace("'", "\\'") + "'"; + } + + @Override + public boolean asBooleanNotNull() { + return !value.isEmpty(); + } + + @Override + public boolean compareInternal(Object val) { + return value.compareTo((String) val) == 0; + } + + @Override + public ClickHouseConstant applyLess(ClickHouseConstant right) { + if (this.getDataType() == right.getDataType()) { + return this.asString().compareTo(right.asString()) <= 0 ? ClickHouseCreateConstant.createTrue() + : ClickHouseCreateConstant.createFalse(); + } + throw new IgnoreMeException(); + } + + @Override + public ClickHouseDataType getDataType() { + return ClickHouseDataType.String; + } + + @Override + public ClickHouseConstant cast(ClickHouseDataType type) { + switch (type) { + case String: + return ClickHouseCreateConstant.createStringConstant(this.toString()); + case UInt8: + return ClickHouseCreateConstant.createUInt8Constant(Integer.parseInt(value)); + case Int8: + return ClickHouseCreateConstant.createInt8Constant(Integer.parseInt(value)); + case UInt16: + return ClickHouseCreateConstant.createUInt16Constant(Integer.parseInt(value)); + case Int16: + return ClickHouseCreateConstant.createInt16Constant(Integer.parseInt(value)); + case UInt32: + return ClickHouseCreateConstant.createUInt32Constant(Integer.parseInt(value)); + case Int32: + return ClickHouseCreateConstant.createInt32Constant(Integer.parseInt(value)); + case UInt64: + return ClickHouseCreateConstant.createUInt64Constant(BigInteger.valueOf(Integer.parseInt(value))); + case Int64: + return ClickHouseCreateConstant.createInt64Constant(BigInteger.valueOf(Integer.parseInt(value))); + case Float32: + return ClickHouseCreateConstant.createFloat32Constant((float) Float.parseFloat(value)); + case Float64: + return ClickHouseCreateConstant.createFloat64Constant((double) Double.parseDouble(value)); + case Nothing: + return ClickHouseCreateConstant.createNullConstant(); + case Bool: + return ClickHouseCreateConstant.createBooleanConstant(value == "true"); + case IntervalYear: + case IntervalQuarter: + case IntervalMonth: + case IntervalWeek: + case IntervalDay: + case IntervalHour: + case IntervalMinute: + case IntervalSecond: + case Date: + case Date32: + case DateTime: + case DateTime32: + case DateTime64: + case Decimal: + case Decimal32: + case Decimal64: + case Decimal128: + case Decimal256: + case UUID: + case Enum: + case Enum8: + case Enum16: + case IPv4: + case IPv6: + case FixedString: + case AggregateFunction: + case SimpleAggregateFunction: + case Array: + case Map: + case Nested: + case Tuple: + case Point: + case Polygon: + case MultiPolygon: + case Ring: + default: + throw new AssertionError(type); + } + } +} diff --git a/src/sqlancer/clickhouse/ast/constant/ClickHouseUInt128Constant.java b/src/sqlancer/clickhouse/ast/constant/ClickHouseUInt128Constant.java new file mode 100644 index 000000000..5952c967c --- /dev/null +++ b/src/sqlancer/clickhouse/ast/constant/ClickHouseUInt128Constant.java @@ -0,0 +1,139 @@ +package sqlancer.clickhouse.ast.constant; + +import java.math.BigInteger; + +import com.clickhouse.client.ClickHouseDataType; + +import sqlancer.IgnoreMeException; +import sqlancer.clickhouse.ast.ClickHouseConstant; + +public class ClickHouseUInt128Constant extends ClickHouseConstant { + + private final BigInteger value; + + public ClickHouseUInt128Constant(BigInteger value) { + this.value = value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + + @Override + public boolean isNull() { + return false; + } + + @Override + public boolean asBooleanNotNull() { + return value != BigInteger.ZERO; + } + + @Override + public ClickHouseDataType getDataType() { + return ClickHouseDataType.UInt128; + } + + @Override + public boolean compareInternal(Object val) { + return value.compareTo((BigInteger) val) == 0; + } + + @Override + public ClickHouseConstant applyLess(ClickHouseConstant right) { + if (this.getDataType() == right.getDataType()) { + return this.asInt() < right.asInt() ? ClickHouseCreateConstant.createTrue() + : ClickHouseCreateConstant.createFalse(); + } + throw new IgnoreMeException(); + } + + @Override + public long asInt() { + return value.longValueExact(); + } + + @Override + public Object getValue() { + return value; + } + + @Override + public ClickHouseConstant cast(ClickHouseDataType type) { + long val = value.longValueExact(); + switch (type) { + case String: + return ClickHouseCreateConstant.createStringConstant(this.toString()); + case UInt8: + return ClickHouseCreateConstant.createUInt8Constant(val); + case Int8: + return ClickHouseCreateConstant.createInt8Constant(val); + case UInt16: + return ClickHouseCreateConstant.createUInt16Constant(val); + case Int16: + return ClickHouseCreateConstant.createInt16Constant(val); + case UInt32: + return ClickHouseCreateConstant.createUInt32Constant(val); + case Int32: + return ClickHouseCreateConstant.createInt32Constant(val); + case UInt64: + return ClickHouseCreateConstant.createUInt64Constant(value); + case Int64: + return ClickHouseCreateConstant.createInt64Constant(value); + case UInt128: + return ClickHouseCreateConstant.createUInt128Constant(value); + case Int128: + return ClickHouseCreateConstant.createInt128Constant(value); + case UInt256: + return ClickHouseCreateConstant.createUInt256Constant(value); + case Int256: + return ClickHouseCreateConstant.createInt256Constant(value); + case Float32: + return ClickHouseCreateConstant.createFloat32Constant(value.floatValue()); + case Float64: + return ClickHouseCreateConstant.createFloat64Constant(value.doubleValue()); + case Nothing: + return ClickHouseCreateConstant.createNullConstant(); + case Bool: + return ClickHouseCreateConstant.createBooleanConstant(val != 0); + case IntervalYear: + case IntervalQuarter: + case IntervalMonth: + case IntervalWeek: + case IntervalDay: + case IntervalHour: + case IntervalMinute: + case IntervalSecond: + case Date: + case Date32: + case DateTime: + case DateTime32: + case DateTime64: + case Decimal: + case Decimal32: + case Decimal64: + case Decimal128: + case Decimal256: + case UUID: + case Enum: + case Enum8: + case Enum16: + case IPv4: + case IPv6: + case FixedString: + case AggregateFunction: + case SimpleAggregateFunction: + case Array: + case Map: + case Nested: + case Tuple: + case Point: + case Polygon: + case MultiPolygon: + case Ring: + default: + throw new AssertionError(type); + } + } +} diff --git a/src/sqlancer/clickhouse/ast/constant/ClickHouseUInt16Constant.java b/src/sqlancer/clickhouse/ast/constant/ClickHouseUInt16Constant.java new file mode 100644 index 000000000..c7ae84fdf --- /dev/null +++ b/src/sqlancer/clickhouse/ast/constant/ClickHouseUInt16Constant.java @@ -0,0 +1,130 @@ +package sqlancer.clickhouse.ast.constant; + +import java.math.BigInteger; + +import com.clickhouse.client.ClickHouseDataType; + +import sqlancer.IgnoreMeException; +import sqlancer.clickhouse.ast.ClickHouseConstant; + +public class ClickHouseUInt16Constant extends ClickHouseConstant { + + private final long value; + + public ClickHouseUInt16Constant(long value) { + this.value = value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + + @Override + public boolean isNull() { + return false; + } + + @Override + public boolean asBooleanNotNull() { + return value != 0; + } + + @Override + public ClickHouseDataType getDataType() { + return ClickHouseDataType.UInt16; + } + + @Override + public boolean compareInternal(Object val) { + return value == (long) val; + } + + @Override + public ClickHouseConstant applyLess(ClickHouseConstant right) { + if (this.getDataType() == right.getDataType()) { + return this.asInt() < right.asInt() ? ClickHouseCreateConstant.createTrue() + : ClickHouseCreateConstant.createFalse(); + } + throw new IgnoreMeException(); + } + + @Override + public long asInt() { + return value; + } + + @Override + public Object getValue() { + return value; + } + + @Override + public ClickHouseConstant cast(ClickHouseDataType type) { + switch (type) { + case String: + return ClickHouseCreateConstant.createStringConstant(this.toString()); + case UInt8: + return ClickHouseCreateConstant.createUInt8Constant(value); + case Int8: + return ClickHouseCreateConstant.createInt8Constant(value); + case UInt16: + return ClickHouseCreateConstant.createUInt16Constant(value); + case Int16: + return ClickHouseCreateConstant.createInt16Constant(value); + case UInt32: + return ClickHouseCreateConstant.createUInt32Constant(value); + case Int32: + return ClickHouseCreateConstant.createInt32Constant(value); + case UInt64: + return ClickHouseCreateConstant.createUInt64Constant(BigInteger.valueOf(value)); + case Int64: + return ClickHouseCreateConstant.createInt64Constant(BigInteger.valueOf(value)); + case Float32: + return ClickHouseCreateConstant.createFloat32Constant((float) value); + case Float64: + return ClickHouseCreateConstant.createFloat64Constant(value); + case Nothing: + return ClickHouseCreateConstant.createNullConstant(); + case Bool: + return ClickHouseCreateConstant.createBooleanConstant(value != 0); + case IntervalYear: + case IntervalQuarter: + case IntervalMonth: + case IntervalWeek: + case IntervalDay: + case IntervalHour: + case IntervalMinute: + case IntervalSecond: + case Date: + case Date32: + case DateTime: + case DateTime32: + case DateTime64: + case Decimal: + case Decimal32: + case Decimal64: + case Decimal128: + case Decimal256: + case UUID: + case Enum: + case Enum8: + case Enum16: + case IPv4: + case IPv6: + case FixedString: + case AggregateFunction: + case SimpleAggregateFunction: + case Array: + case Map: + case Nested: + case Tuple: + case Point: + case Polygon: + case MultiPolygon: + case Ring: + default: + throw new AssertionError(type); + } + } +} diff --git a/src/sqlancer/clickhouse/ast/constant/ClickHouseUInt256Constant.java b/src/sqlancer/clickhouse/ast/constant/ClickHouseUInt256Constant.java new file mode 100644 index 000000000..95b333323 --- /dev/null +++ b/src/sqlancer/clickhouse/ast/constant/ClickHouseUInt256Constant.java @@ -0,0 +1,139 @@ +package sqlancer.clickhouse.ast.constant; + +import java.math.BigInteger; + +import com.clickhouse.client.ClickHouseDataType; + +import sqlancer.IgnoreMeException; +import sqlancer.clickhouse.ast.ClickHouseConstant; + +public class ClickHouseUInt256Constant extends ClickHouseConstant { + + private final BigInteger value; + + public ClickHouseUInt256Constant(BigInteger value) { + this.value = value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + + @Override + public boolean isNull() { + return false; + } + + @Override + public boolean asBooleanNotNull() { + return value != BigInteger.ZERO; + } + + @Override + public ClickHouseDataType getDataType() { + return ClickHouseDataType.UInt256; + } + + @Override + public boolean compareInternal(Object val) { + return value.compareTo((BigInteger) val) == 0; + } + + @Override + public ClickHouseConstant applyLess(ClickHouseConstant right) { + if (this.getDataType() == right.getDataType()) { + return this.asInt() < right.asInt() ? ClickHouseCreateConstant.createTrue() + : ClickHouseCreateConstant.createFalse(); + } + throw new IgnoreMeException(); + } + + @Override + public long asInt() { + return value.longValueExact(); + } + + @Override + public Object getValue() { + return value; + } + + @Override + public ClickHouseConstant cast(ClickHouseDataType type) { + long val = value.longValueExact(); + switch (type) { + case String: + return ClickHouseCreateConstant.createStringConstant(this.toString()); + case UInt8: + return ClickHouseCreateConstant.createUInt8Constant(val); + case Int8: + return ClickHouseCreateConstant.createInt8Constant(val); + case UInt16: + return ClickHouseCreateConstant.createUInt16Constant(val); + case Int16: + return ClickHouseCreateConstant.createInt16Constant(val); + case UInt32: + return ClickHouseCreateConstant.createUInt32Constant(val); + case Int32: + return ClickHouseCreateConstant.createInt32Constant(val); + case UInt64: + return ClickHouseCreateConstant.createUInt64Constant(value); + case Int64: + return ClickHouseCreateConstant.createInt64Constant(value); + case UInt128: + return ClickHouseCreateConstant.createUInt128Constant(value); + case Int128: + return ClickHouseCreateConstant.createInt128Constant(value); + case UInt256: + return ClickHouseCreateConstant.createUInt256Constant(value); + case Int256: + return ClickHouseCreateConstant.createInt256Constant(value); + case Float32: + return ClickHouseCreateConstant.createFloat32Constant(value.floatValue()); + case Float64: + return ClickHouseCreateConstant.createFloat64Constant(value.doubleValue()); + case Nothing: + return ClickHouseCreateConstant.createNullConstant(); + case Bool: + return ClickHouseCreateConstant.createBooleanConstant(val != 0); + case IntervalYear: + case IntervalQuarter: + case IntervalMonth: + case IntervalWeek: + case IntervalDay: + case IntervalHour: + case IntervalMinute: + case IntervalSecond: + case Date: + case Date32: + case DateTime: + case DateTime32: + case DateTime64: + case Decimal: + case Decimal32: + case Decimal64: + case Decimal128: + case Decimal256: + case UUID: + case Enum: + case Enum8: + case Enum16: + case IPv4: + case IPv6: + case FixedString: + case AggregateFunction: + case SimpleAggregateFunction: + case Array: + case Map: + case Nested: + case Tuple: + case Point: + case Polygon: + case MultiPolygon: + case Ring: + default: + throw new AssertionError(type); + } + } +} diff --git a/src/sqlancer/clickhouse/ast/constant/ClickHouseUInt32Constant.java b/src/sqlancer/clickhouse/ast/constant/ClickHouseUInt32Constant.java new file mode 100644 index 000000000..8e4f5fd09 --- /dev/null +++ b/src/sqlancer/clickhouse/ast/constant/ClickHouseUInt32Constant.java @@ -0,0 +1,130 @@ +package sqlancer.clickhouse.ast.constant; + +import java.math.BigInteger; + +import com.clickhouse.client.ClickHouseDataType; + +import sqlancer.IgnoreMeException; +import sqlancer.clickhouse.ast.ClickHouseConstant; + +public class ClickHouseUInt32Constant extends ClickHouseConstant { + + private final long value; + + public ClickHouseUInt32Constant(long value) { + this.value = value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + + @Override + public boolean isNull() { + return false; + } + + @Override + public boolean asBooleanNotNull() { + return value != 0; + } + + @Override + public ClickHouseDataType getDataType() { + return ClickHouseDataType.UInt32; + } + + @Override + public boolean compareInternal(Object val) { + return value == (long) val; + } + + @Override + public ClickHouseConstant applyLess(ClickHouseConstant right) { + if (this.getDataType() == right.getDataType()) { + return this.asInt() < right.asInt() ? ClickHouseCreateConstant.createTrue() + : ClickHouseCreateConstant.createFalse(); + } + throw new IgnoreMeException(); + } + + @Override + public long asInt() { + return value; + } + + @Override + public Object getValue() { + return value; + } + + @Override + public ClickHouseConstant cast(ClickHouseDataType type) { + switch (type) { + case String: + return ClickHouseCreateConstant.createStringConstant(this.toString()); + case UInt8: + return ClickHouseCreateConstant.createUInt8Constant(value); + case Int8: + return ClickHouseCreateConstant.createInt8Constant(value); + case UInt16: + return ClickHouseCreateConstant.createUInt16Constant(value); + case Int16: + return ClickHouseCreateConstant.createInt16Constant(value); + case UInt32: + return ClickHouseCreateConstant.createUInt32Constant(value); + case Int32: + return ClickHouseCreateConstant.createInt32Constant(value); + case UInt64: + return ClickHouseCreateConstant.createUInt64Constant(BigInteger.valueOf(value)); + case Int64: + return ClickHouseCreateConstant.createInt64Constant(BigInteger.valueOf(value)); + case Float32: + return ClickHouseCreateConstant.createFloat32Constant((float) value); + case Float64: + return ClickHouseCreateConstant.createFloat64Constant(value); + case Nothing: + return ClickHouseCreateConstant.createNullConstant(); + case Bool: + return ClickHouseCreateConstant.createBooleanConstant(value != 0); + case IntervalYear: + case IntervalQuarter: + case IntervalMonth: + case IntervalWeek: + case IntervalDay: + case IntervalHour: + case IntervalMinute: + case IntervalSecond: + case Date: + case Date32: + case DateTime: + case DateTime32: + case DateTime64: + case Decimal: + case Decimal32: + case Decimal64: + case Decimal128: + case Decimal256: + case UUID: + case Enum: + case Enum8: + case Enum16: + case IPv4: + case IPv6: + case FixedString: + case AggregateFunction: + case SimpleAggregateFunction: + case Array: + case Map: + case Nested: + case Tuple: + case Point: + case Polygon: + case MultiPolygon: + case Ring: + default: + throw new AssertionError(type); + } + } +} diff --git a/src/sqlancer/clickhouse/ast/constant/ClickHouseUInt64Constant.java b/src/sqlancer/clickhouse/ast/constant/ClickHouseUInt64Constant.java new file mode 100644 index 000000000..cd3363850 --- /dev/null +++ b/src/sqlancer/clickhouse/ast/constant/ClickHouseUInt64Constant.java @@ -0,0 +1,131 @@ +package sqlancer.clickhouse.ast.constant; + +import java.math.BigInteger; + +import com.clickhouse.client.ClickHouseDataType; + +import sqlancer.IgnoreMeException; +import sqlancer.clickhouse.ast.ClickHouseConstant; + +public class ClickHouseUInt64Constant extends ClickHouseConstant { + + private final BigInteger value; + + public ClickHouseUInt64Constant(BigInteger value) { + this.value = value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + + @Override + public boolean isNull() { + return false; + } + + @Override + public boolean asBooleanNotNull() { + return value != BigInteger.ZERO; + } + + @Override + public ClickHouseDataType getDataType() { + return ClickHouseDataType.UInt64; + } + + @Override + public boolean compareInternal(Object val) { + return value.compareTo((BigInteger) val) == 0; + } + + @Override + public ClickHouseConstant applyLess(ClickHouseConstant right) { + if (this.getDataType() == right.getDataType()) { + return this.asInt() < right.asInt() ? ClickHouseCreateConstant.createTrue() + : ClickHouseCreateConstant.createFalse(); + } + throw new IgnoreMeException(); + } + + @Override + public long asInt() { + return value.longValueExact(); + } + + @Override + public Object getValue() { + return value; + } + + @Override + public ClickHouseConstant cast(ClickHouseDataType type) { + long val = value.longValueExact(); + switch (type) { + case String: + return ClickHouseCreateConstant.createStringConstant(this.toString()); + case UInt8: + return ClickHouseCreateConstant.createUInt8Constant(val); + case Int8: + return ClickHouseCreateConstant.createInt8Constant(val); + case UInt16: + return ClickHouseCreateConstant.createUInt16Constant(val); + case Int16: + return ClickHouseCreateConstant.createInt16Constant(val); + case UInt32: + return ClickHouseCreateConstant.createUInt32Constant(val); + case Int32: + return ClickHouseCreateConstant.createInt32Constant(val); + case UInt64: + return ClickHouseCreateConstant.createUInt64Constant(value); + case Int64: + return ClickHouseCreateConstant.createInt64Constant(value); + case Float32: + return ClickHouseCreateConstant.createFloat32Constant(value.floatValue()); + case Float64: + return ClickHouseCreateConstant.createFloat64Constant(value.doubleValue()); + case Nothing: + return ClickHouseCreateConstant.createNullConstant(); + case Bool: + return ClickHouseCreateConstant.createBooleanConstant(val != 0); + case IntervalYear: + case IntervalQuarter: + case IntervalMonth: + case IntervalWeek: + case IntervalDay: + case IntervalHour: + case IntervalMinute: + case IntervalSecond: + case Date: + case Date32: + case DateTime: + case DateTime32: + case DateTime64: + case Decimal: + case Decimal32: + case Decimal64: + case Decimal128: + case Decimal256: + case UUID: + case Enum: + case Enum8: + case Enum16: + case IPv4: + case IPv6: + case FixedString: + case AggregateFunction: + case SimpleAggregateFunction: + case Array: + case Map: + case Nested: + case Tuple: + case Point: + case Polygon: + case MultiPolygon: + case Ring: + default: + throw new AssertionError(type); + } + } +} diff --git a/src/sqlancer/clickhouse/ast/constant/ClickHouseUInt8Constant.java b/src/sqlancer/clickhouse/ast/constant/ClickHouseUInt8Constant.java new file mode 100644 index 000000000..de0d308c3 --- /dev/null +++ b/src/sqlancer/clickhouse/ast/constant/ClickHouseUInt8Constant.java @@ -0,0 +1,138 @@ +package sqlancer.clickhouse.ast.constant; + +import java.math.BigInteger; + +import com.clickhouse.client.ClickHouseDataType; + +import sqlancer.IgnoreMeException; +import sqlancer.clickhouse.ast.ClickHouseConstant; + +public class ClickHouseUInt8Constant extends ClickHouseConstant { + + private final int value; + + public ClickHouseUInt8Constant(int value) { + this.value = value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + + @Override + public boolean isNull() { + return false; + } + + @Override + public boolean asBooleanNotNull() { + return value != 0; + } + + @Override + public ClickHouseDataType getDataType() { + return ClickHouseDataType.UInt8; + } + + @Override + public boolean compareInternal(Object val) { + return value == (int) val; + } + + @Override + public ClickHouseConstant applyLess(ClickHouseConstant right) { + if (this.getDataType() == right.getDataType()) { + return this.asInt() < right.asInt() ? ClickHouseCreateConstant.createTrue() + : ClickHouseCreateConstant.createFalse(); + } + throw new IgnoreMeException(); + } + + @Override + public long asInt() { + return value; + } + + @Override + public Object getValue() { + return value; + } + + @Override + public ClickHouseConstant cast(ClickHouseDataType type) { + switch (type) { + case String: + return ClickHouseCreateConstant.createStringConstant(this.toString()); + case UInt8: + return ClickHouseCreateConstant.createUInt8Constant(value); + case Int8: + return ClickHouseCreateConstant.createInt8Constant(value); + case UInt16: + return ClickHouseCreateConstant.createUInt16Constant(value); + case Int16: + return ClickHouseCreateConstant.createInt16Constant(value); + case UInt32: + return ClickHouseCreateConstant.createUInt32Constant(value); + case Int32: + return ClickHouseCreateConstant.createInt32Constant(value); + case UInt64: + return ClickHouseCreateConstant.createUInt64Constant(BigInteger.valueOf(value)); + case Int64: + return ClickHouseCreateConstant.createInt64Constant(BigInteger.valueOf(value)); + case UInt128: + return ClickHouseCreateConstant.createUInt128Constant(BigInteger.valueOf(value)); + case Int128: + return ClickHouseCreateConstant.createInt128Constant(BigInteger.valueOf(value)); + case UInt256: + return ClickHouseCreateConstant.createUInt256Constant(BigInteger.valueOf(value)); + case Int256: + return ClickHouseCreateConstant.createInt256Constant(BigInteger.valueOf(value)); + case Float32: + return ClickHouseCreateConstant.createFloat32Constant((float) value); + case Float64: + return ClickHouseCreateConstant.createFloat64Constant((double) value); + case Nothing: + return ClickHouseCreateConstant.createNullConstant(); + case Bool: + return ClickHouseCreateConstant.createBooleanConstant(value != 0); + case IntervalYear: + case IntervalQuarter: + case IntervalMonth: + case IntervalWeek: + case IntervalDay: + case IntervalHour: + case IntervalMinute: + case IntervalSecond: + case Date: + case Date32: + case DateTime: + case DateTime32: + case DateTime64: + case Decimal: + case Decimal32: + case Decimal64: + case Decimal128: + case Decimal256: + case UUID: + case Enum: + case Enum8: + case Enum16: + case IPv4: + case IPv6: + case FixedString: + case AggregateFunction: + case SimpleAggregateFunction: + case Array: + case Map: + case Nested: + case Tuple: + case Point: + case Polygon: + case MultiPolygon: + case Ring: + default: + throw new AssertionError(type); + } + } +} diff --git a/src/sqlancer/clickhouse/gen/ClickHouseColumnBuilder.java b/src/sqlancer/clickhouse/gen/ClickHouseColumnBuilder.java index 65eb65623..2584ff0b3 100644 --- a/src/sqlancer/clickhouse/gen/ClickHouseColumnBuilder.java +++ b/src/sqlancer/clickhouse/gen/ClickHouseColumnBuilder.java @@ -61,12 +61,11 @@ public String createColumn(String columnName, ClickHouseProvider.ClickHouseGloba if (allowMaterialized) { sb.append(" MATERIALIZED ("); sb.append( - ClickHouseVisitor - .asString(new ClickHouseExpressionGenerator(globalState) - .setColumns( - columns.stream().filter(p -> !p.getName().contentEquals(columnName)) - .collect(Collectors.toList())) - .generateExpression(dataType))); + ClickHouseVisitor.asString( + new ClickHouseExpressionGenerator(globalState).generateExpressionWithColumns( + columns.stream().filter(p -> !p.getName().contentEquals(columnName)) + .map(p -> p.asColumnReference(null)).collect(Collectors.toList()), + 2))); sb.append(")"); } break; diff --git a/src/sqlancer/clickhouse/gen/ClickHouseCommon.java b/src/sqlancer/clickhouse/gen/ClickHouseCommon.java index 612e807d5..deae2a9aa 100644 --- a/src/sqlancer/clickhouse/gen/ClickHouseCommon.java +++ b/src/sqlancer/clickhouse/gen/ClickHouseCommon.java @@ -29,7 +29,7 @@ public static List getTableRefs(List tableRefs = new ArrayList<>(); for (ClickHouseSchema.ClickHouseTable t : tables) { ClickHouseTableReference tableRef; - tableRef = new ClickHouseTableReference(t); + tableRef = new ClickHouseTableReference(t, null); tableRefs.add(tableRef); } return tableRefs; diff --git a/src/sqlancer/clickhouse/gen/ClickHouseExpressionGenerator.java b/src/sqlancer/clickhouse/gen/ClickHouseExpressionGenerator.java index 1054055f9..4e04e91f4 100644 --- a/src/sqlancer/clickhouse/gen/ClickHouseExpressionGenerator.java +++ b/src/sqlancer/clickhouse/gen/ClickHouseExpressionGenerator.java @@ -3,51 +3,196 @@ import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import com.clickhouse.client.ClickHouseDataType; -import ru.yandex.clickhouse.domain.ClickHouseDataType; import sqlancer.Randomly; import sqlancer.clickhouse.ClickHouseProvider.ClickHouseGlobalState; import sqlancer.clickhouse.ClickHouseSchema; import sqlancer.clickhouse.ClickHouseSchema.ClickHouseColumn; import sqlancer.clickhouse.ClickHouseSchema.ClickHouseLancerDataType; +import sqlancer.clickhouse.ClickHouseSchema.ClickHouseTable; import sqlancer.clickhouse.ast.ClickHouseAggregate; +import sqlancer.clickhouse.ast.ClickHouseAggregate.ClickHouseAggregateFunction; +import sqlancer.clickhouse.ast.ClickHouseAliasOperation; +import sqlancer.clickhouse.ast.ClickHouseBinaryArithmeticOperation; import sqlancer.clickhouse.ast.ClickHouseBinaryComparisonOperation; +import sqlancer.clickhouse.ast.ClickHouseBinaryFunctionOperation; import sqlancer.clickhouse.ast.ClickHouseBinaryLogicalOperation; import sqlancer.clickhouse.ast.ClickHouseColumnReference; -import sqlancer.clickhouse.ast.ClickHouseConstant; import sqlancer.clickhouse.ast.ClickHouseExpression; +import sqlancer.clickhouse.ast.ClickHouseExpression.ClickHouseJoin; +import sqlancer.clickhouse.ast.ClickHouseSelect; +import sqlancer.clickhouse.ast.ClickHouseTableReference; +import sqlancer.clickhouse.ast.ClickHouseUnaryFunctionOperation; import sqlancer.clickhouse.ast.ClickHouseUnaryPostfixOperation; import sqlancer.clickhouse.ast.ClickHouseUnaryPostfixOperation.ClickHouseUnaryPostfixOperator; import sqlancer.clickhouse.ast.ClickHouseUnaryPrefixOperation; import sqlancer.clickhouse.ast.ClickHouseUnaryPrefixOperation.ClickHouseUnaryPrefixOperator; +import sqlancer.clickhouse.ast.constant.ClickHouseCreateConstant; +import sqlancer.common.gen.NoRECGenerator; +import sqlancer.common.gen.TLPWhereGenerator; import sqlancer.common.gen.TypedExpressionGenerator; +import sqlancer.common.schema.AbstractTables; public class ClickHouseExpressionGenerator - extends TypedExpressionGenerator { + extends TypedExpressionGenerator implements + NoRECGenerator, + TLPWhereGenerator { private final ClickHouseGlobalState globalState; public boolean allowAggregateFunctions; + private List tables; + private final List columnRefs; + public ClickHouseExpressionGenerator(ClickHouseGlobalState globalState) { this.globalState = globalState; + this.columnRefs = new ArrayList<>(); + } + + public final void addColumns(List col) { + this.columnRefs.addAll(col); + } + + private enum ColumnLike { + UNARY_PREFIX, BINARY_ARITHMETIC, UNARY_FUNCTION, BINARY_FUNCTION } private enum Expression { - UNARY_POSTFIX, UNARY_PREFIX, BINARY_COMPARISON, BINARY_LOGICAL + UNARY_PREFIX, BINARY_ARITHMETIC, UNARY_FUNCTION, BINARY_FUNCTION, BINARY_LOGICAL, BINARY_COMPARISON, + UNARY_POSTFIX + } + + public ClickHouseExpression generateExpressionWithColumns(List columns, + int remainingDepth) { + if (columns.isEmpty() || remainingDepth <= 2 && Randomly.getBooleanWithRatherLowProbability()) { + return generateConstant(null); + } + + if (remainingDepth <= 2 || Randomly.getBooleanWithRatherLowProbability()) { + return columns.get((int) Randomly.getNotCachedInteger(0, columns.size() - 1)); + } + + ColumnLike expr = Randomly.fromOptions(ColumnLike.values()); + switch (expr) { + case UNARY_PREFIX: + return new ClickHouseUnaryPrefixOperation(generateExpressionWithColumns(columns, remainingDepth - 1), + ClickHouseUnaryPrefixOperator.MINUS); + case BINARY_ARITHMETIC: + return new ClickHouseBinaryArithmeticOperation(generateExpressionWithColumns(columns, remainingDepth - 1), + generateExpressionWithColumns(columns, remainingDepth - 1), + ClickHouseBinaryArithmeticOperation.ClickHouseBinaryArithmeticOperator.getRandom()); + case UNARY_FUNCTION: + return new ClickHouseUnaryFunctionOperation(generateExpressionWithColumns(columns, remainingDepth - 1), + ClickHouseUnaryFunctionOperation.ClickHouseUnaryFunctionOperator.getRandom()); + case BINARY_FUNCTION: + return new ClickHouseBinaryFunctionOperation(generateExpressionWithColumns(columns, remainingDepth - 1), + generateExpressionWithColumns(columns, remainingDepth - 1), + ClickHouseBinaryFunctionOperation.ClickHouseBinaryFunctionOperator.getRandom()); + default: + throw new AssertionError(expr); + } + } + + public ClickHouseExpression generateAggregateExpressionWithColumns(List columns, + int remainingDepth) { + if (Randomly.getBooleanWithRatherLowProbability()) { + return new ClickHouseAggregate(generateExpressionWithColumns(columns, remainingDepth - 1), + ClickHouseAggregate.ClickHouseAggregateFunction.getRandom()); + } + if (columns.isEmpty() || remainingDepth <= 2 && Randomly.getBooleanWithRatherLowProbability()) { + return generateConstant(null); + } + + if (remainingDepth <= 2 || Randomly.getBooleanWithRatherLowProbability()) { + return columns.get((int) Randomly.getNotCachedInteger(0, columns.size() - 1)); + } + + ColumnLike expr = Randomly.fromOptions(ColumnLike.values()); + switch (expr) { + case UNARY_PREFIX: + return new ClickHouseUnaryPrefixOperation(generateExpressionWithColumns(columns, remainingDepth - 1), + ClickHouseUnaryPrefixOperator.MINUS); + case BINARY_ARITHMETIC: + return new ClickHouseBinaryArithmeticOperation(generateExpressionWithColumns(columns, remainingDepth - 1), + generateExpressionWithColumns(columns, remainingDepth - 1), + ClickHouseBinaryArithmeticOperation.ClickHouseBinaryArithmeticOperator.getRandom()); + case UNARY_FUNCTION: + return new ClickHouseUnaryFunctionOperation(generateExpressionWithColumns(columns, remainingDepth - 1), + ClickHouseUnaryFunctionOperation.ClickHouseUnaryFunctionOperator.getRandom()); + case BINARY_FUNCTION: + return new ClickHouseBinaryFunctionOperation(generateExpressionWithColumns(columns, remainingDepth - 1), + generateExpressionWithColumns(columns, remainingDepth - 1), + ClickHouseBinaryFunctionOperation.ClickHouseBinaryFunctionOperator.getRandom()); + default: + throw new AssertionError(expr); + } + } + + public ClickHouseExpression generateExpressionWithExpression(List expression, + int remainingDepth) { + if (remainingDepth <= 2 || Randomly.getBooleanWithRatherLowProbability()) { + if (Randomly.getBoolean()) { + return expression.get((int) Randomly.getNotCachedInteger(0, expression.size() - 1)); + } else { + return generateConstant(null); + } + } + + Expression type = Randomly.fromOptions(Expression.values()); + switch (type) { + case UNARY_PREFIX: + return new ClickHouseUnaryPrefixOperation(generateExpressionWithExpression(expression, remainingDepth - 1), + ClickHouseUnaryPrefixOperation.ClickHouseUnaryPrefixOperator.getRandom()); + case UNARY_POSTFIX: + return new ClickHouseUnaryPostfixOperation(generateExpressionWithExpression(expression, remainingDepth - 1), + ClickHouseUnaryPostfixOperation.ClickHouseUnaryPostfixOperator.getRandom(), false); + case BINARY_COMPARISON: + return new ClickHouseBinaryComparisonOperation( + generateExpressionWithExpression(expression, remainingDepth - 1), + generateExpressionWithExpression(expression, remainingDepth - 1), + ClickHouseBinaryComparisonOperation.ClickHouseBinaryComparisonOperator.getRandomOperator()); + case BINARY_LOGICAL: + return new ClickHouseBinaryLogicalOperation( + generateExpressionWithExpression(expression, remainingDepth - 1), + generateExpressionWithExpression(expression, remainingDepth - 1), + ClickHouseBinaryLogicalOperation.ClickHouseBinaryLogicalOperator.getRandom()); + case BINARY_ARITHMETIC: + return new ClickHouseBinaryArithmeticOperation( + generateExpressionWithExpression(expression, remainingDepth - 1), + generateExpressionWithExpression(expression, remainingDepth - 1), + ClickHouseBinaryArithmeticOperation.ClickHouseBinaryArithmeticOperator.getRandom()); + case UNARY_FUNCTION: + return new ClickHouseUnaryFunctionOperation( + generateExpressionWithExpression(expression, remainingDepth - 1), + ClickHouseUnaryFunctionOperation.ClickHouseUnaryFunctionOperator.getRandom()); + case BINARY_FUNCTION: + return new ClickHouseBinaryFunctionOperation( + generateExpressionWithExpression(expression, remainingDepth - 1), + generateExpressionWithExpression(expression, remainingDepth - 1), + ClickHouseBinaryFunctionOperation.ClickHouseBinaryFunctionOperator.getRandom()); + default: + throw new AssertionError(type); + } } @Override protected ClickHouseExpression generateExpression(ClickHouseLancerDataType type, int depth) { - if (allowAggregateFunctions && Randomly.getBoolean()) { - return generateAggregate(); + if (allowAggregateFunctions && Randomly.getBooleanWithRatherLowProbability()) { + ClickHouseLancerDataType aggType = ClickHouseLancerDataType.getRandom(); + return new ClickHouseAggregate(generateExpression(aggType, depth + 1), + ClickHouseAggregate.ClickHouseAggregateFunction.getRandom()); } - if (depth >= globalState.getOptions().getMaxExpressionDepth() || Randomly.getBoolean()) { + if (depth >= globalState.getOptions().getMaxExpressionDepth() + || Randomly.getBooleanWithRatherLowProbability()) { return generateLeafNode(type); } Expression expr = Randomly.fromOptions(Expression.values()); ClickHouseLancerDataType leftLeafType = ClickHouseLancerDataType.getRandom(); ClickHouseLancerDataType rightLeafType = ClickHouseLancerDataType.getRandom(); - if (Randomly.getBoolean()) { + if (Randomly.getBooleanWithRatherLowProbability()) { rightLeafType = leftLeafType; } @@ -66,21 +211,52 @@ protected ClickHouseExpression generateExpression(ClickHouseLancerDataType type, return new ClickHouseBinaryLogicalOperation(generateExpression(leftLeafType, depth + 1), generateExpression(rightLeafType, depth + 1), ClickHouseBinaryLogicalOperation.ClickHouseBinaryLogicalOperator.getRandom()); + case BINARY_ARITHMETIC: + return new ClickHouseBinaryArithmeticOperation(generateExpression(leftLeafType, depth + 1), + generateExpression(leftLeafType, depth + 1), + ClickHouseBinaryArithmeticOperation.ClickHouseBinaryArithmeticOperator.getRandom()); + case UNARY_FUNCTION: + return new ClickHouseUnaryFunctionOperation(generateExpression(leftLeafType, depth + 1), + ClickHouseUnaryFunctionOperation.ClickHouseUnaryFunctionOperator.getRandom()); + case BINARY_FUNCTION: + return new ClickHouseBinaryFunctionOperation(generateExpression(leftLeafType, depth + 1), + generateExpression(leftLeafType, depth + 1), + ClickHouseBinaryFunctionOperation.ClickHouseBinaryFunctionOperator.getRandom()); default: throw new AssertionError(expr); } } + protected ClickHouseExpression.ClickHouseJoinOnClause generateJoinClause(ClickHouseTableReference leftTable, + ClickHouseTableReference rightTable) { + List leftColumns = leftTable.getColumnReferences(); + List rightColumns = rightTable.getColumnReferences(); + ClickHouseExpression leftExpr = generateExpressionWithColumns(leftColumns, 2); + ClickHouseExpression rightExpr = generateExpressionWithColumns(rightColumns, 2); + return new ClickHouseExpression.ClickHouseJoinOnClause(leftExpr, rightExpr); + } + @Override protected ClickHouseExpression generateColumn(ClickHouseLancerDataType type) { - if (columns.isEmpty()) { + if (columnRefs.isEmpty()) { return generateConstant(type); } - List filteredColumns = columns.stream() - .filter(c -> c.getType().getType().name().equals(type.getType().name())).collect(Collectors.toList()); - ClickHouseColumn column = filteredColumns.isEmpty() ? Randomly.fromList(columns) - : Randomly.fromList(filteredColumns); - return new ClickHouseColumnReference(column, null); + List filteredColumns = columnRefs.stream() + .filter(c -> c.getColumn().getType().getType().name().equals(type.getType().name())) + .collect(Collectors.toList()); + return filteredColumns.isEmpty() ? Randomly.fromList(columnRefs) : Randomly.fromList(filteredColumns); + } + + protected ClickHouseExpression getColumnNameFromTable(ClickHouseSchema.ClickHouseTable table) { + if (columnRefs.isEmpty()) { + return generateConstant(ClickHouseLancerDataType.getRandom()); + } + List filteredColumns = columnRefs.stream() + .filter(c -> c.getColumn().getTable() == table).collect(Collectors.toList()); + if (filteredColumns.isEmpty()) { + return generateConstant(ClickHouseLancerDataType.getRandom()); + } + return Randomly.fromList(filteredColumns); } @Override @@ -88,29 +264,29 @@ protected ClickHouseLancerDataType getRandomType() { return ClickHouseLancerDataType.getRandom(); } - public List getRandomJoinClauses( + public List getRandomJoinClauses(ClickHouseTableReference left, List tables) { List joinStatements = new ArrayList<>(); if (!globalState.getDbmsSpecificOptions().testJoins) { return joinStatements; } - if (Randomly.getBoolean() && tables.size() > 1) { + List leftTables = new ArrayList<>(); + leftTables.add(left); + if (Randomly.getBoolean() && !tables.isEmpty()) { int nrJoinClauses = (int) Randomly.getNotCachedInteger(0, tables.size()); for (int i = 0; i < nrJoinClauses; i++) { - ClickHouseExpression joinClause = generateExpression(ClickHouseLancerDataType.getRandom()); - ClickHouseSchema.ClickHouseTable table = Randomly.fromList(tables); - tables.remove(table); - ClickHouseExpression.ClickHouseJoin.JoinType options; - options = Randomly.fromOptions(ClickHouseExpression.ClickHouseJoin.JoinType.values()); - if (options == ClickHouseExpression.ClickHouseJoin.JoinType.NATURAL) { - // NATURAL joins do not have an ON clause - joinClause = null; - } - ClickHouseExpression.ClickHouseJoin j = new ClickHouseExpression.ClickHouseJoin(table, joinClause, - options); + ClickHouseTableReference leftTable = leftTables + .get((int) Randomly.getNotCachedInteger(0, leftTables.size() - 1)); + ClickHouseTableReference rightTable = new ClickHouseTableReference(Randomly.fromList(tables), + "right_" + i); + ClickHouseExpression.ClickHouseJoinOnClause joinClause = generateJoinClause(leftTable, rightTable); + ClickHouseExpression.ClickHouseJoin.JoinType options = Randomly + .fromOptions(ClickHouseExpression.ClickHouseJoin.JoinType.values()); + ClickHouseExpression.ClickHouseJoin j = new ClickHouseExpression.ClickHouseJoin(leftTable, rightTable, + options, joinClause); joinStatements.add(j); + leftTables.add(rightTable); } - } return joinStatements; } @@ -121,7 +297,8 @@ protected boolean canGenerateColumnOfType(ClickHouseLancerDataType type) { } @Override - public ClickHouseExpression generateConstant(ClickHouseLancerDataType type) { + public ClickHouseExpression generateConstant(ClickHouseLancerDataType genType) { + ClickHouseLancerDataType type = (genType == null) ? ClickHouseLancerDataType.getRandom() : genType; switch (type.getType()) { case Int8: case UInt8: @@ -131,33 +308,30 @@ public ClickHouseExpression generateConstant(ClickHouseLancerDataType type) { case UInt32: case Int64: case UInt64: - return ClickHouseConstant.createIntConstant(type.getType(), globalState.getRandomly().getInteger()); + return ClickHouseCreateConstant.createIntConstant(type.getType(), globalState.getRandomly().getInteger()); case Float32: - return ClickHouseConstant.createFloat32Constant((float) globalState.getRandomly().getDouble()); + return ClickHouseCreateConstant.createFloat32Constant((float) globalState.getRandomly().getDouble()); case Float64: - return ClickHouseConstant.createFloat64Constant(globalState.getRandomly().getDouble()); + return ClickHouseCreateConstant.createFloat64Constant(globalState.getRandomly().getDouble()); case String: - return ClickHouseConstant.createStringConstant(globalState.getRandomly().getString()); + return ClickHouseCreateConstant.createStringConstant(globalState.getRandomly().getString()); default: throw new AssertionError(); } } public ClickHouseExpression getHavingClause() { - allowAggregateFunctions = true; - return generateExpression(new ClickHouseLancerDataType(ClickHouseDataType.UInt8)); + return generateAggregate(); } public ClickHouseAggregate generateArgsForAggregate(ClickHouseDataType dataType, ClickHouseAggregate.ClickHouseAggregateFunction agg) { - List types = agg.getTypes(dataType); - List args = new ArrayList<>(); - for (ClickHouseDataType argType : types) { - this.allowAggregateFunctions = false; - args.add(generateExpression(new ClickHouseLancerDataType(argType))); - this.allowAggregateFunctions = true; - } - return new ClickHouseAggregate(args, agg); + ClickHouseDataType type = agg.getType(dataType); + this.allowAggregateFunctions = false; + ClickHouseExpression arg = generateExpression(new ClickHouseLancerDataType(type)); + this.allowAggregateFunctions = true; + + return new ClickHouseAggregate(arg, agg); } public ClickHouseExpressionGenerator allowAggregates(boolean value) { @@ -166,19 +340,12 @@ public ClickHouseExpressionGenerator allowAggregates(boolean value) { } public ClickHouseExpression generateAggregate() { - return getAggregate(ClickHouseLancerDataType.getRandom().getType()); - } - - private ClickHouseExpression getAggregate(ClickHouseDataType dataType) { - List aggregates = ClickHouseAggregate.ClickHouseAggregateFunction - .getAggregates(dataType); - ClickHouseAggregate.ClickHouseAggregateFunction agg = Randomly.fromList(aggregates); - return generateArgsForAggregate(dataType, agg); + return generateAggregateExpressionWithColumns(columnRefs, 3); } @Override public ClickHouseExpression generatePredicate() { - return generateExpression(new ClickHouseSchema.ClickHouseLancerDataType(ClickHouseDataType.UInt8)); + return generateExpressionWithColumns(columnRefs, 3); } @Override @@ -190,4 +357,94 @@ public ClickHouseExpression negatePredicate(ClickHouseExpression predicate) { public ClickHouseExpression isNull(ClickHouseExpression expr) { return new ClickHouseUnaryPostfixOperation(expr, ClickHouseUnaryPostfixOperator.IS_NULL, false); } + + @Override + public ClickHouseExpressionGenerator setTablesAndColumns(AbstractTables tables) { + this.tables = tables.getTables(); + this.columns = tables.getColumns(); + return this; + } + + @Override + public ClickHouseExpression generateBooleanExpression() { + List columnRefs = columns.stream() + .map(c -> c.asColumnReference(c.getTable().getName())).collect(Collectors.toList()); + return generateExpressionWithColumns(columnRefs, 5); + } + + @Override + public ClickHouseSelect generateSelect() { + return new ClickHouseSelect(); + } + + @Override + public List getRandomJoinClauses() { + List joinStatements = new ArrayList<>(); + if (globalState.getClickHouseOptions().testJoins && Randomly.getBoolean()) { + return joinStatements; + } + List leftTables = new ArrayList<>(); + leftTables.add(new ClickHouseTableReference(tables.get(0), null)); + if (Randomly.getBoolean() && !tables.isEmpty()) { + int nrJoinClauses = (int) Randomly.getNotCachedInteger(0, tables.size()); + for (int i = 0; i < nrJoinClauses; i++) { + ClickHouseTableReference leftTable = leftTables + .get((int) Randomly.getNotCachedInteger(0, leftTables.size() - 1)); + ClickHouseTableReference rightTable = new ClickHouseTableReference(Randomly.fromList(tables), + "right_" + i); + ClickHouseExpression.ClickHouseJoinOnClause joinClause = generateJoinClause(leftTable, rightTable); + ClickHouseExpression.ClickHouseJoin.JoinType options = Randomly + .fromOptions(ClickHouseExpression.ClickHouseJoin.JoinType.values()); + ClickHouseExpression.ClickHouseJoin j = new ClickHouseExpression.ClickHouseJoin(leftTable, rightTable, + options, joinClause); + joinStatements.add(j); + leftTables.add(rightTable); + } + } + return joinStatements; + } + + @Override + public List getTableRefs() { + return tables.stream().map(t -> new ClickHouseTableReference(t, null)).collect(Collectors.toList()); + } + + @Override + public String generateOptimizedQueryString(ClickHouseSelect select, ClickHouseExpression whereCondition, + boolean shouldUseAggregate) { + List filteredColumns = Randomly.extractNrRandomColumns(columns, + (int) Randomly.getNotCachedInteger(1, columns.size())); + if (shouldUseAggregate) { + ClickHouseAggregate aggr = new ClickHouseAggregate( + new ClickHouseColumnReference(ClickHouseColumn.createDummy("*", null), null, null), + ClickHouseAggregateFunction.COUNT); + select.setFetchColumns(List.of(aggr)); + } else { + select.setFetchColumns(filteredColumns.stream().map(c -> c.asColumnReference(c.getTable().getName())) + .collect(Collectors.toList())); + } + select.setWhereClause(whereCondition); + + return select.asString(); + } + + @Override + public String generateUnoptimizedQueryString(ClickHouseSelect select, ClickHouseExpression whereCondition) { + ClickHouseExpression inner = new ClickHouseAliasOperation(whereCondition, "check"); + + select.setFetchColumns(List.of(inner)); + select.setWhereClause(null); + return "SELECT SUM(check <> 0) FROM (" + select.asString() + ") as res"; + } + + @Override + public List generateFetchColumns(boolean shouldCreateDummy) { + if (shouldCreateDummy) { + return List.of(new ClickHouseColumnReference(ClickHouseColumn.createDummy("*", null), null, null)); + } + List columnReferences = columns.stream() + .map(c -> c.asColumnReference(c.getTable().getName())).collect(Collectors.toList()); + return IntStream.range(0, 1 + Randomly.smallNumber()) + .mapToObj(i -> generateExpressionWithColumns(columnReferences, 5)).collect(Collectors.toList()); + } } diff --git a/src/sqlancer/clickhouse/gen/ClickHouseTableGenerator.java b/src/sqlancer/clickhouse/gen/ClickHouseTableGenerator.java index 4d6f97008..3df5c5524 100644 --- a/src/sqlancer/clickhouse/gen/ClickHouseTableGenerator.java +++ b/src/sqlancer/clickhouse/gen/ClickHouseTableGenerator.java @@ -2,8 +2,8 @@ import java.util.ArrayList; import java.util.List; +import java.util.stream.Collectors; -import ru.yandex.clickhouse.domain.ClickHouseDataType; import sqlancer.Randomly; import sqlancer.clickhouse.ClickHouseErrors; import sqlancer.clickhouse.ClickHouseProvider; @@ -37,7 +37,7 @@ public static SQLQueryAdapter createTableStatement(String tableName, ClickHouseTableGenerator chTableGenerator = new ClickHouseTableGenerator(tableName, globalState); chTableGenerator.start(); ExpectedErrors errors = new ExpectedErrors(); - ClickHouseErrors.addTableManipulationErrors(errors); + ClickHouseErrors.addExpectedExpressionErrors(errors); return new SQLQueryAdapter(chTableGenerator.sb.toString(), errors, true); } @@ -55,7 +55,7 @@ public void start() { sb.append(" ("); int nrColumns = 1 + Randomly.smallNumber(); for (int i = 0; i < nrColumns; i++) { - columns.add(ClickHouseSchema.ClickHouseColumn.createDummy(ClickHouseCommon.createColumnName(i))); + columns.add(ClickHouseSchema.ClickHouseColumn.createDummy(ClickHouseCommon.createColumnName(i), null)); } for (int i = 0; i < nrColumns; i++) { if (i != 0) { @@ -79,8 +79,8 @@ public void start() { if (engine == ClickHouseEngine.MergeTree) { if (Randomly.getBoolean()) { sb.append(" ORDER BY "); - ClickHouseExpression expr = gen - .generateExpression(ClickHouseSchema.ClickHouseLancerDataType.getRandom()); + ClickHouseExpression expr = gen.generateExpressionWithColumns( + columns.stream().map(c -> c.asColumnReference(null)).collect(Collectors.toList()), 3); sb.append(ClickHouseToStringVisitor.asString(expr)); } else { sb.append(" ORDER BY tuple() "); @@ -88,16 +88,18 @@ public void start() { if (Randomly.getBoolean()) { sb.append(" PARTITION BY "); - ClickHouseExpression expr = gen - .generateExpression(ClickHouseSchema.ClickHouseLancerDataType.getRandom()); + ClickHouseExpression expr = gen.generateExpressionWithColumns( + columns.stream().map(c -> c.asColumnReference(null)).collect(Collectors.toList()), 3); sb.append(ClickHouseToStringVisitor.asString(expr)); } if (Randomly.getBoolean()) { sb.append(" SAMPLE BY "); - ClickHouseExpression expr = gen - .generateExpression(ClickHouseSchema.ClickHouseLancerDataType.getRandom()); + ClickHouseExpression expr = gen.generateExpressionWithColumns( + columns.stream().map(c -> c.asColumnReference(null)).collect(Collectors.toList()), 3); sb.append(ClickHouseToStringVisitor.asString(expr)); } + // Suppress index sanity checks https://github.com/sqlancer/sqlancer/issues/788 + sb.append(" SETTINGS allow_suspicious_indices=1"); // TODO: PRIMARY KEY } @@ -109,8 +111,8 @@ private void addColumnsConstraint(ClickHouseExpressionGenerator gen) { sb.append(" CONSTRAINT "); sb.append(ClickHouseCommon.createConstraintName(i)); sb.append(" CHECK "); - ClickHouseExpression expr = gen - .generateExpression(new ClickHouseSchema.ClickHouseLancerDataType(ClickHouseDataType.UInt8)); + ClickHouseExpression expr = gen.generateExpressionWithColumns( + columns.stream().map(c -> c.asColumnReference(null)).collect(Collectors.toList()), 2); sb.append(ClickHouseToStringVisitor.asString(expr)); } } diff --git a/src/sqlancer/clickhouse/oracle/tlp/ClickHouseTLPAggregateOracle.java b/src/sqlancer/clickhouse/oracle/tlp/ClickHouseTLPAggregateOracle.java index 7bf1f01a7..cedac9c51 100644 --- a/src/sqlancer/clickhouse/oracle/tlp/ClickHouseTLPAggregateOracle.java +++ b/src/sqlancer/clickhouse/oracle/tlp/ClickHouseTLPAggregateOracle.java @@ -3,66 +3,62 @@ import java.sql.SQLException; import java.util.Arrays; import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; -import ru.yandex.clickhouse.domain.ClickHouseDataType; import sqlancer.ComparatorHelper; import sqlancer.Randomly; import sqlancer.clickhouse.ClickHouseErrors; import sqlancer.clickhouse.ClickHouseProvider; -import sqlancer.clickhouse.ClickHouseSchema; import sqlancer.clickhouse.ClickHouseVisitor; import sqlancer.clickhouse.ast.ClickHouseAggregate; -import sqlancer.clickhouse.ast.ClickHouseExpression; -import sqlancer.clickhouse.ast.ClickHouseSelect; -import sqlancer.clickhouse.ast.ClickHouseUnaryPostfixOperation; -import sqlancer.clickhouse.ast.ClickHouseUnaryPrefixOperation; -import sqlancer.clickhouse.gen.ClickHouseCommon; -import sqlancer.clickhouse.gen.ClickHouseExpressionGenerator; +import sqlancer.clickhouse.ast.ClickHouseAliasOperation; public class ClickHouseTLPAggregateOracle extends ClickHouseTLPBase { - private ClickHouseExpressionGenerator gen; - public ClickHouseTLPAggregateOracle(ClickHouseProvider.ClickHouseGlobalState state) { super(state); ClickHouseErrors.addExpectedExpressionErrors(errors); - ClickHouseErrors.addQueryErrors(errors); } @Override public void check() throws SQLException { - ClickHouseSchema s = state.getSchema(); - ClickHouseSchema.ClickHouseTables targetTables = s.getRandomTableNonEmptyTables(); - gen = new ClickHouseExpressionGenerator(state).setColumns(targetTables.getColumns()); - ClickHouseSelect select = new ClickHouseSelect(); + super.check(); + if (Randomly.getBooleanWithRatherLowProbability()) { + select.setOrderByClauses(IntStream.range(0, 1 + Randomly.smallNumber()) + .mapToObj(i -> gen.generateExpressionWithColumns(columns, 5)).collect(Collectors.toList())); + } + ClickHouseAggregate.ClickHouseAggregateFunction windowFunction = Randomly.fromOptions( ClickHouseAggregate.ClickHouseAggregateFunction.MIN, ClickHouseAggregate.ClickHouseAggregateFunction.MAX, ClickHouseAggregate.ClickHouseAggregateFunction.SUM); - ClickHouseAggregate aggregate = new ClickHouseAggregate( - gen.generateExpressions(ClickHouseSchema.ClickHouseLancerDataType.getRandom(), 1), windowFunction); + + ClickHouseAggregate aggregate = new ClickHouseAggregate(gen.generateExpressionWithColumns(columns, 6), + windowFunction); select.setFetchColumns(Arrays.asList(aggregate)); - List from = ClickHouseCommon.getTableRefs(targetTables.getTables(), s); - select.setFromList(from); - if (Randomly.getBoolean()) { - select.setOrderByExpressions(gen.generateOrderBys()); - } + String originalQuery = ClickHouseVisitor.asString(select); originalQuery += " SETTINGS aggregate_functions_null_for_empty = 1"; - ClickHouseExpression whereClause = gen - .generateExpression(new ClickHouseSchema.ClickHouseLancerDataType(ClickHouseDataType.UInt8)); - ClickHouseUnaryPrefixOperation negatedClause = new ClickHouseUnaryPrefixOperation(whereClause, - ClickHouseUnaryPrefixOperation.ClickHouseUnaryPrefixOperator.NOT); - ClickHouseUnaryPostfixOperation notNullClause = new ClickHouseUnaryPostfixOperation(whereClause, - ClickHouseUnaryPostfixOperation.ClickHouseUnaryPostfixOperator.IS_NULL, false); + select.setFetchColumns(Arrays.asList(new ClickHouseAliasOperation(aggregate, "aggr"))); + + select.setWhereClause(predicate); + if (Randomly.getBooleanWithRatherLowProbability()) { + select.setGroupByClause(IntStream.range(0, 1 + Randomly.smallNumber()) + .mapToObj(i -> gen.generateExpressionWithColumns(columns, 5)).collect(Collectors.toList())); + } + if (Randomly.getBoolean()) { + select.setOrderByClauses(IntStream.range(0, 1 + Randomly.smallNumber()) + .mapToObj(i -> gen.generateExpressionWithColumns(columns, 5)).collect(Collectors.toList())); + } - ClickHouseSelect leftSelect = getSelect(aggregate, from, whereClause); - ClickHouseSelect middleSelect = getSelect(aggregate, from, negatedClause); - ClickHouseSelect rightSelect = getSelect(aggregate, from, notNullClause); String metamorphicText = "SELECT " + aggregate.getFunc().toString() + "(aggr) FROM ("; - metamorphicText += ClickHouseVisitor.asString(leftSelect) + " UNION ALL " - + ClickHouseVisitor.asString(middleSelect) + " UNION ALL " + ClickHouseVisitor.asString(rightSelect); + metamorphicText += ClickHouseVisitor.asString(select) + " UNION ALL "; + select.setWhereClause(negatedPredicate); + metamorphicText += ClickHouseVisitor.asString(select) + " UNION ALL "; + select.setWhereClause(isNullPredicate); + metamorphicText += ClickHouseVisitor.asString(select); metamorphicText += ")"; metamorphicText += " SETTINGS aggregate_functions_null_for_empty = 1"; List firstResult = ComparatorHelper.getResultSetFirstColumnAsString(originalQuery, errors, state); @@ -75,9 +71,9 @@ public void check() throws SQLException { if (firstResult.size() != secondResult.size()) { throw new AssertionError(); - } else if (firstResult.isEmpty()) { + } else if (firstResult.isEmpty() || firstResult.equals(secondResult)) { return; - } else if (firstResult.size() == 1) { + } else if (firstResult.size() == 1 && secondResult.size() == 1) { if (firstResult.get(0).equals(secondResult.get(0))) { return; } else if (!ComparatorHelper.isEqualDouble(firstResult.get(0), secondResult.get(0))) { @@ -88,20 +84,4 @@ public void check() throws SQLException { } } - private ClickHouseSelect getSelect(ClickHouseAggregate aggregate, List from, - ClickHouseExpression whereClause) { - ClickHouseSelect leftSelect = new ClickHouseSelect(); - leftSelect.setFetchColumns( - Arrays.asList(new ClickHouseExpression.ClickHousePostfixText(aggregate, " as aggr", null))); - leftSelect.setFromList(from); - leftSelect.setWhereClause(whereClause); - if (Randomly.getBooleanWithRatherLowProbability()) { - leftSelect.setGroupByClause(gen.generateExpressions(Randomly.smallNumber() + 1)); - } - if (Randomly.getBoolean()) { - leftSelect.setOrderByExpressions(gen.generateOrderBys()); - } - return leftSelect; - } - } diff --git a/src/sqlancer/clickhouse/oracle/tlp/ClickHouseTLPBase.java b/src/sqlancer/clickhouse/oracle/tlp/ClickHouseTLPBase.java index 045aadc51..9374442ff 100644 --- a/src/sqlancer/clickhouse/oracle/tlp/ClickHouseTLPBase.java +++ b/src/sqlancer/clickhouse/oracle/tlp/ClickHouseTLPBase.java @@ -1,58 +1,76 @@ package sqlancer.clickhouse.oracle.tlp; +import static java.lang.Math.min; +import static java.util.stream.IntStream.range; + import java.sql.SQLException; import java.util.List; import java.util.stream.Collectors; +import sqlancer.ComparatorHelper; import sqlancer.Randomly; import sqlancer.clickhouse.ClickHouseErrors; import sqlancer.clickhouse.ClickHouseProvider.ClickHouseGlobalState; import sqlancer.clickhouse.ClickHouseSchema; import sqlancer.clickhouse.ClickHouseSchema.ClickHouseTable; -import sqlancer.clickhouse.ClickHouseSchema.ClickHouseTables; +import sqlancer.clickhouse.ClickHouseVisitor; import sqlancer.clickhouse.ast.ClickHouseColumnReference; import sqlancer.clickhouse.ast.ClickHouseExpression; import sqlancer.clickhouse.ast.ClickHouseExpression.ClickHouseJoin; import sqlancer.clickhouse.ast.ClickHouseSelect; -import sqlancer.clickhouse.gen.ClickHouseCommon; +import sqlancer.clickhouse.ast.ClickHouseTableReference; import sqlancer.clickhouse.gen.ClickHouseExpressionGenerator; import sqlancer.common.gen.ExpressionGenerator; import sqlancer.common.oracle.TernaryLogicPartitioningOracleBase; import sqlancer.common.oracle.TestOracle; public class ClickHouseTLPBase extends TernaryLogicPartitioningOracleBase - implements TestOracle { + implements TestOracle { - ClickHouseSchema s; - ClickHouseTables targetTables; + ClickHouseSchema schema; + List columns; ClickHouseExpressionGenerator gen; ClickHouseSelect select; public ClickHouseTLPBase(ClickHouseGlobalState state) { super(state); ClickHouseErrors.addExpectedExpressionErrors(errors); - ClickHouseErrors.addQueryErrors(errors); } @Override public void check() throws SQLException { - s = state.getSchema(); - targetTables = s.getRandomTableNonEmptyTables(); - gen = new ClickHouseExpressionGenerator(state).setColumns(targetTables.getColumns()); - initializeTernaryPredicateVariants(); + gen = new ClickHouseExpressionGenerator(state); + schema = state.getSchema(); select = new ClickHouseSelect(); - select.setFetchColumns(generateFetchColumns()); - List tables = targetTables.getTables(); - List joinStatements = gen.getRandomJoinClauses(tables); - List tableRefs = ClickHouseCommon.getTableRefs(tables, s); - select.setJoinClauses(joinStatements.stream().collect(Collectors.toList())); - select.setFromTables(tableRefs); + List tables = schema.getRandomTableNonEmptyTables().getTables(); + ClickHouseTableReference table = new ClickHouseTableReference( + tables.get((int) Randomly.getNotCachedInteger(0, tables.size())), + Randomly.getBoolean() ? "left" : null); + select.setFromClause(table); + columns = table.getColumnReferences(); + + if (state.getClickHouseOptions().testJoins && Randomly.getBoolean()) { + List joinStatements = gen.getRandomJoinClauses(table, tables); + columns.addAll(joinStatements.stream().flatMap(j -> j.getRightTable().getColumnReferences().stream()) + .collect(Collectors.toList())); + select.setJoinClauses(joinStatements); + } + gen.addColumns(columns); + int small = Randomly.smallNumber(); + List from = range(0, 1 + small) + .mapToObj(i -> gen.generateExpressionWithColumns(columns, 5)).collect(Collectors.toList()); + select.setFetchColumns(from); select.setWhereClause(null); + initializeTernaryPredicateVariants(); + // Smoke check + String query = ClickHouseVisitor.asString(select); + ComparatorHelper.getResultSetFirstColumnAsString(query, errors, state); } - List generateFetchColumns() { - return Randomly.nonEmptySubset(targetTables.getColumns()).stream() - .map(c -> new ClickHouseColumnReference(c, null)).collect(Collectors.toList()); + List generateFetchColumns(List columns) { + List list = Randomly.extractNrRandomColumns(columns, + min(1 + Randomly.smallNumber(), columns.size())); + return list.stream().map(c -> (ClickHouseExpression) c).collect(Collectors.toList()); } @Override diff --git a/src/sqlancer/clickhouse/oracle/tlp/ClickHouseTLPGroupByOracle.java b/src/sqlancer/clickhouse/oracle/tlp/ClickHouseTLPGroupByOracle.java index bb3629604..e8881cb84 100644 --- a/src/sqlancer/clickhouse/oracle/tlp/ClickHouseTLPGroupByOracle.java +++ b/src/sqlancer/clickhouse/oracle/tlp/ClickHouseTLPGroupByOracle.java @@ -4,12 +4,12 @@ import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; +import java.util.stream.IntStream; import sqlancer.ComparatorHelper; import sqlancer.Randomly; import sqlancer.clickhouse.ClickHouseProvider; import sqlancer.clickhouse.ClickHouseVisitor; -import sqlancer.clickhouse.ast.ClickHouseColumnReference; import sqlancer.clickhouse.ast.ClickHouseExpression; public class ClickHouseTLPGroupByOracle extends ClickHouseTLPBase { @@ -21,7 +21,10 @@ public ClickHouseTLPGroupByOracle(ClickHouseProvider.ClickHouseGlobalState state @Override public void check() throws SQLException { super.check(); - select.setGroupByClause(select.getFetchColumns()); + List groupByColumns = IntStream.range(0, 1 + Randomly.smallNumber()) + .mapToObj(i -> gen.generateExpressionWithColumns(columns, 5)).collect(Collectors.toList()); + + select.setGroupByClause(groupByColumns); select.setWhereClause(null); String originalQueryString = ClickHouseVisitor.asString(select); @@ -34,18 +37,9 @@ public void check() throws SQLException { select.setWhereClause(isNullPredicate); String thirdQueryString = ClickHouseVisitor.asString(select); List combinedString = new ArrayList<>(); - List secondResultSet = ComparatorHelper.getCombinedResultSetNoDuplicates(firstQueryString, - secondQueryString, thirdQueryString, combinedString, false, state, errors); + List secondResultSet = ComparatorHelper.getCombinedResultSet(firstQueryString, secondQueryString, + thirdQueryString, combinedString, true, state, errors); ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, state); } - - @Override - List generateFetchColumns() { - List columns; - columns = Randomly.nonEmptySubset(targetTables.getColumns()).stream() - .map(c -> new ClickHouseColumnReference(c, null)).collect(Collectors.toList()); - return columns; - } - } diff --git a/src/sqlancer/clickhouse/oracle/tlp/ClickHouseTLPHavingOracle.java b/src/sqlancer/clickhouse/oracle/tlp/ClickHouseTLPHavingOracle.java index 2cc2410d1..a6f3e0cdb 100644 --- a/src/sqlancer/clickhouse/oracle/tlp/ClickHouseTLPHavingOracle.java +++ b/src/sqlancer/clickhouse/oracle/tlp/ClickHouseTLPHavingOracle.java @@ -4,55 +4,51 @@ import java.util.HashSet; import java.util.List; import java.util.stream.Collectors; +import java.util.stream.IntStream; import sqlancer.ComparatorHelper; +import sqlancer.IgnoreMeException; import sqlancer.Randomly; import sqlancer.clickhouse.ClickHouseErrors; import sqlancer.clickhouse.ClickHouseProvider; -import sqlancer.clickhouse.ClickHouseSchema; import sqlancer.clickhouse.ClickHouseVisitor; -import sqlancer.clickhouse.ast.ClickHouseColumnReference; +import sqlancer.clickhouse.ast.ClickHouseAggregate; import sqlancer.clickhouse.ast.ClickHouseExpression; import sqlancer.clickhouse.ast.ClickHouseSelect; import sqlancer.clickhouse.ast.ClickHouseUnaryPostfixOperation; import sqlancer.clickhouse.ast.ClickHouseUnaryPrefixOperation; -import sqlancer.clickhouse.gen.ClickHouseCommon; -import sqlancer.clickhouse.gen.ClickHouseExpressionGenerator; public class ClickHouseTLPHavingOracle extends ClickHouseTLPBase { public ClickHouseTLPHavingOracle(ClickHouseProvider.ClickHouseGlobalState state) { super(state); ClickHouseErrors.addExpectedExpressionErrors(errors); - ClickHouseErrors.addGroupingErrors(errors); } @Override public void check() throws SQLException { - ClickHouseSchema s = state.getSchema(); - ClickHouseSchema.ClickHouseTables targetTables = s.getRandomTableNonEmptyTables(); - List groupByColumns = Randomly.nonEmptySubset(targetTables.getColumns()).stream() - .map(c -> new ClickHouseColumnReference(c, null)).collect(Collectors.toList()); - List columns = targetTables.getColumns(); - ClickHouseExpressionGenerator gen = new ClickHouseExpressionGenerator(state).setColumns(columns); - ClickHouseExpressionGenerator aggrGen = new ClickHouseExpressionGenerator(state).allowAggregates(true) - .setColumns(columns); - ClickHouseSelect select = new ClickHouseSelect(); - select.setFetchColumns(aggrGen.generateExpressions(Randomly.smallNumber() + 1)); - List tables = targetTables.getTables(); - List joinStatements = gen.getRandomJoinClauses(tables); - List from = ClickHouseCommon.getTableRefs(tables, state.getSchema()); - select.setJoinClauses(joinStatements); + super.check(); + select.setFetchColumns(IntStream.range(0, Randomly.smallNumber() + 1) + .mapToObj(i -> gen.generateAggregateExpressionWithColumns(columns, 3)).collect(Collectors.toList())); select.setSelectType(ClickHouseSelect.SelectType.ALL); - select.setFromTables(from); // TODO order by? + + List groupByColumns = IntStream.range(0, 1 + Randomly.smallNumber()) + .mapToObj(i -> gen.generateExpressionWithColumns(columns, 6)).collect(Collectors.toList()); + select.setGroupByClause(groupByColumns); select.setHavingClause(null); String originalQueryString = ClickHouseVisitor.asString(select); + originalQueryString += " SETTINGS aggregate_functions_null_for_empty=1, enable_optimize_predicate_expression=0"; // https://github.com/ClickHouse/ClickHouse/issues/12264 List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); - ClickHouseExpression predicate = aggrGen.getHavingClause(); + List aggregateExprs = select.getFetchColumns().stream() + .filter(p -> p instanceof ClickHouseAggregate).collect(Collectors.toList()); + if (aggregateExprs.isEmpty()) { + throw new IgnoreMeException(); + } + ClickHouseExpression predicate = gen.generateExpressionWithExpression(aggregateExprs, 6); select.setHavingClause(predicate); String firstQueryString = ClickHouseVisitor.asString(select); select.setHavingClause(new ClickHouseUnaryPrefixOperation(predicate, diff --git a/src/sqlancer/clickhouse/oracle/tlp/ClickHouseTLPWhereOracle.java b/src/sqlancer/clickhouse/oracle/tlp/ClickHouseTLPWhereOracle.java deleted file mode 100644 index 5023b803a..000000000 --- a/src/sqlancer/clickhouse/oracle/tlp/ClickHouseTLPWhereOracle.java +++ /dev/null @@ -1,46 +0,0 @@ -package sqlancer.clickhouse.oracle.tlp; - -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.List; - -import sqlancer.ComparatorHelper; -import sqlancer.Randomly; -import sqlancer.clickhouse.ClickHouseErrors; -import sqlancer.clickhouse.ClickHouseProvider; -import sqlancer.clickhouse.ClickHouseVisitor; - -public class ClickHouseTLPWhereOracle extends ClickHouseTLPBase { - - public ClickHouseTLPWhereOracle(ClickHouseProvider.ClickHouseGlobalState state) { - super(state); - ClickHouseErrors.addExpectedExpressionErrors(errors); - ClickHouseErrors.addExpressionHavingErrors(errors); - } - - @Override - public void check() throws SQLException { - super.check(); - if (Randomly.getBooleanWithRatherLowProbability()) { - select.setOrderByExpressions(gen.generateOrderBys()); - } - String originalQueryString = ClickHouseVisitor.asString(select); - List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); - - boolean orderBy = Randomly.getBooleanWithRatherLowProbability(); - if (orderBy) { - select.setOrderByExpressions(gen.generateOrderBys()); - } - select.setWhereClause(predicate); - String firstQueryString = ClickHouseVisitor.asString(select); - select.setWhereClause(negatedPredicate); - String secondQueryString = ClickHouseVisitor.asString(select); - select.setWhereClause(isNullPredicate); - String thirdQueryString = ClickHouseVisitor.asString(select); - List combinedString = new ArrayList<>(); - List secondResultSet = ComparatorHelper.getCombinedResultSet(firstQueryString, secondQueryString, - thirdQueryString, combinedString, !orderBy, state, errors); - ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, - state); - } -} diff --git a/src/sqlancer/cnosdb/CnosDBBugs.java b/src/sqlancer/cnosdb/CnosDBBugs.java new file mode 100644 index 000000000..4e6eb96e9 --- /dev/null +++ b/src/sqlancer/cnosdb/CnosDBBugs.java @@ -0,0 +1,13 @@ +package sqlancer.cnosdb; + +public final class CnosDBBugs { + + // https://github.com/cnosdb/cnosdb/issues/786 + public static final boolean BUG786 = true; + + // https://github.com/apache/arrow-rs/issues/3547 + public static final boolean BUG3547 = true; + + private CnosDBBugs() { + } +} diff --git a/src/sqlancer/cnosdb/CnosDBComparatorHelper.java b/src/sqlancer/cnosdb/CnosDBComparatorHelper.java new file mode 100644 index 000000000..46b6ba615 --- /dev/null +++ b/src/sqlancer/cnosdb/CnosDBComparatorHelper.java @@ -0,0 +1,145 @@ +package sqlancer.cnosdb; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.function.UnaryOperator; +import java.util.stream.Collectors; + +import sqlancer.IgnoreMeException; +import sqlancer.cnosdb.client.CnosDBResultSet; +import sqlancer.cnosdb.query.CnosDBSelectQuery; +import sqlancer.common.query.ExpectedErrors; + +public final class CnosDBComparatorHelper { + + private CnosDBComparatorHelper() { + } + + public static List getResultSetFirstColumnAsString(String queryString, ExpectedErrors errors, + CnosDBGlobalState state) throws Exception { + if (state.getOptions().logEachSelect()) { + // TODO: refactor me + state.getLogger().writeCurrent(queryString); + try { + state.getLogger().getCurrentFileWriter().flush(); + } catch (IOException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + } + CnosDBSelectQuery q = new CnosDBSelectQuery(queryString, errors); + List result = new ArrayList<>(); + CnosDBResultSet resultSet; + try { + q.executeAndGet(state); + resultSet = q.getResultSet(); + if (resultSet == null) { + throw new AssertionError(q); + } + while (resultSet.next()) { + result.add(resultSet.getString(1)); + } + } catch (Exception e) { + if (e instanceof IgnoreMeException) { + throw e; + } + if (e instanceof NumberFormatException) { + throw new IgnoreMeException(); + } + if (e.getMessage() == null) { + throw new AssertionError(queryString, e); + } + if (errors.errorIsExpected(e.getMessage())) { + throw new IgnoreMeException(); + } + throw new AssertionError(queryString, e); + } + + return result; + } + + public static void assumeResultSetsAreEqual(List resultSet, List secondResultSet, + String originalQueryString, List combinedString, CnosDBGlobalState state) { + if (resultSet.size() != secondResultSet.size()) { + String queryFormatString = "-- %s;\n-- cardinality: %d"; + String firstQueryString = String.format(queryFormatString, originalQueryString, resultSet.size()); + String secondQueryString = String.format(queryFormatString, String.join(";", combinedString), + secondResultSet.size()); + state.getState().getLocalState().log(String.format("%s\n%s", firstQueryString, secondQueryString)); + String assertionMessage = String.format("the size of the result sets mismatch (%d and %d)!\n%s\n%s", + resultSet.size(), secondResultSet.size(), firstQueryString, secondQueryString); + throw new AssertionError(assertionMessage); + } + + Set firstHashSet = new HashSet<>(resultSet); + Set secondHashSet = new HashSet<>(secondResultSet); + + if (!firstHashSet.equals(secondHashSet)) { + Set firstResultSetMisses = new HashSet<>(firstHashSet); + firstResultSetMisses.removeAll(secondHashSet); + Set secondResultSetMisses = new HashSet<>(secondHashSet); + secondResultSetMisses.removeAll(firstHashSet); + String queryFormatString = "-- %s;\n-- misses: %s"; + String firstQueryString = String.format(queryFormatString, originalQueryString, firstResultSetMisses); + String secondQueryString = String.format(queryFormatString, String.join(";", combinedString), + secondResultSetMisses); + // update the SELECT queries to be logged at the bottom of the error log file + state.getState().getLocalState().log(String.format("%s\n%s", firstQueryString, secondQueryString)); + String assertionMessage = String.format("the content of the result sets mismatch!\n%s\n%s", + firstQueryString, secondQueryString); + throw new AssertionError(assertionMessage); + } + } + + public static void assumeResultSetsAreEqual(List resultSet, List secondResultSet, + String originalQueryString, List combinedString, CnosDBGlobalState state, + UnaryOperator canonicalizationRule) { + // Overloaded version of assumeResultSetsAreEqual that takes a canonicalization function which is applied to + // both result sets before their comparison. + List canonicalizedResultSet = resultSet.stream().map(canonicalizationRule).collect(Collectors.toList()); + List canonicalizedSecondResultSet = secondResultSet.stream().map(canonicalizationRule) + .collect(Collectors.toList()); + assumeResultSetsAreEqual(canonicalizedResultSet, canonicalizedSecondResultSet, originalQueryString, + combinedString, state); + } + + public static List getCombinedResultSet(String firstQueryString, String secondQueryString, + String thirdQueryString, List combinedString, boolean asUnion, CnosDBGlobalState state, + ExpectedErrors errors) throws Exception { + List secondResultSet; + if (asUnion) { + String unionString = firstQueryString + " UNION ALL " + secondQueryString + " UNION ALL " + + thirdQueryString; + combinedString.add(unionString); + secondResultSet = getResultSetFirstColumnAsString(unionString, errors, state); + } else { + secondResultSet = new ArrayList<>(); + secondResultSet.addAll(getResultSetFirstColumnAsString(firstQueryString, errors, state)); + secondResultSet.addAll(getResultSetFirstColumnAsString(secondQueryString, errors, state)); + secondResultSet.addAll(getResultSetFirstColumnAsString(thirdQueryString, errors, state)); + combinedString.add(firstQueryString); + combinedString.add(secondQueryString); + combinedString.add(thirdQueryString); + } + return secondResultSet; + } + + public static List getCombinedResultSetNoDuplicates(String firstQueryString, String secondQueryString, + String thirdQueryString, List combinedString, boolean asUnion, CnosDBGlobalState state, + ExpectedErrors errors) throws Exception { + String unionString; + if (asUnion) { + unionString = firstQueryString + " UNION " + secondQueryString + " UNION " + thirdQueryString; + } else { + unionString = "SELECT DISTINCT * FROM (" + firstQueryString + " UNION ALL " + secondQueryString + + " UNION ALL " + thirdQueryString + ")"; + } + List secondResultSet; + combinedString.add(unionString); + secondResultSet = getResultSetFirstColumnAsString(unionString, errors, state); + return secondResultSet; + } +} diff --git a/src/sqlancer/cnosdb/CnosDBCompoundDataType.java b/src/sqlancer/cnosdb/CnosDBCompoundDataType.java new file mode 100644 index 000000000..034f0fc90 --- /dev/null +++ b/src/sqlancer/cnosdb/CnosDBCompoundDataType.java @@ -0,0 +1,20 @@ +package sqlancer.cnosdb; + +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; + +public final class CnosDBCompoundDataType { + + private final CnosDBDataType dataType; + + private CnosDBCompoundDataType(CnosDBDataType dataType) { + this.dataType = dataType; + } + + public static CnosDBCompoundDataType create(CnosDBDataType type) { + return new CnosDBCompoundDataType(type); + } + + public CnosDBDataType getDataType() { + return dataType; + } +} diff --git a/src/sqlancer/cnosdb/CnosDBExpectedError.java b/src/sqlancer/cnosdb/CnosDBExpectedError.java new file mode 100644 index 000000000..61dba101b --- /dev/null +++ b/src/sqlancer/cnosdb/CnosDBExpectedError.java @@ -0,0 +1,87 @@ +package sqlancer.cnosdb; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.common.query.ExpectedErrors; + +public final class CnosDBExpectedError { + + private CnosDBExpectedError() { + } + + public static List getExpectedErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("have the same name. Consider aliasing"); + errors.add( + "error: Optimizer rule 'projection_push_down' failed due to unexpected error: Schema error: Schema contains duplicate qualified field name"); + errors.add("Projection references non-aggregate values:"); + errors.add("External err: Schema error: No field named"); + errors.add( + "Optimizer rule 'common_sub_expression_eliminate' failed due to unexpected error: Schema error: No field named"); + errors.add("Binary"); + errors.add("Invalid pattern in LIKE expression"); + errors.add("If the projection contains the time column, it must contain the field column."); + errors.add("Schema error: No field named"); + errors.add("Optimizer rule 'simplify_expressions' failed due to unexpected error:"); + errors.add("err: Internal error: Optimizer rule 'projection_push_down' failed due to unexpected error"); + errors.add("Schema error: No field named "); + errors.add("err: External err: Schema error: No field named"); + errors.add("Optimizer rule 'simplify_expressions' failed due to unexpected error"); + errors.add("Csv error: CSV Writer does not support List"); + errors.add("This feature is not implemented: cross join."); + errors.add("Execution error: field position must be greater than zero"); + errors.add("First argument of `DATE_PART` must be non-null scalar Utf8"); + errors.add("Cannot create filter with non-boolean predicate 'NULL' returning Null"); + errors.add("requested character too large for encoding."); + errors.add("Can not find compatible types to compare Boolean with [Utf8]."); + errors.add("Cannot create filter with non-boolean predicate 'APPROXDISTINCT"); + errors.add("HAVING clause references non-aggregate values:"); + errors.add("Cannot create filter with non-boolean predicate"); + errors.add("negative substring length not allowed"); + errors.add("The function Sum does not support inputs of type Boolean."); + errors.add("The function Avg does not support inputs of type Boolean."); + errors.add("Percentile value must be between 0.0 and 1.0 inclusive"); + errors.add("Date part '' not supported"); + errors.add("Min/Max accumulator not implemented for type Boolean."); + errors.add("meta need get_series_id_by_filter"); + errors.add("Arrow: Cast error:"); + errors.add("Arrow error: Cast error:"); + errors.add("Datafusion: Execution error: Arrow error: External error: Arrow error: Cast error:"); + errors.add("Arrow error: Divide by zero error"); + errors.add("desired percentile argument must be float literal"); + errors.add("Unsupported CAST from Int32 to Timestamp(Nanosecond, None)"); + errors.add("Execution error: Date part"); + errors.add("Physical plan does not support logical expression MIN(Boolean"); + errors.add("The percentile argument for ApproxPercentileCont must be Float64, not Int64"); + errors.add("The percentile argument for ApproxPercentileContWithWeight must be Float64, not Int64."); + errors.add("Data type UInt64 not supported for binary operation '#' on dyn arrays."); + errors.add("Arrow: Divide by zero error"); + errors.add("The function ApproxPercentileCont does not support inputs of type Null."); + errors.add("can't be evaluated because there isn't a common type to coerce the types to"); + errors.add("This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug"); + errors.add("The function ApproxMedian does not support inputs of type Null."); + errors.add("null character not permitted."); + errors.add("The percentile argument for ApproxPercentileCont must be Float64, not Null."); + errors.add("This feature is not implemented"); + errors.add("The function Avg does not support inputs of type Null."); + errors.add("Coercion from [Utf8, Timestamp(Nanosecond, Some(\\\"+00:00\\\"))]"); + errors.add( + "Coercion from [Utf8, Float64, Utf8] to the signature OneOf([Exact([Utf8, Int64]), Exact([LargeUtf8, Int64]), Exact([Utf8, Int64, Utf8]), Exact([LargeUtf8, Int64, Utf8]), Exact([Utf8, Int64, LargeUtf8]), Exact([LargeUtf8, Int64, LargeUtf8])]) failed."); + errors.add("Coercion from"); + + errors.add("Error parsing timestamp"); + errors.add("lpad requested length"); + errors.add("rpad requested length"); + errors.add("No function matches the given name and argument types"); + return errors; + } + + public static ExpectedErrors expectedErrors() { + ExpectedErrors res = new ExpectedErrors(); + res.addAll(getExpectedErrors()); + return res; + } + +} diff --git a/src/sqlancer/cnosdb/CnosDBGlobalState.java b/src/sqlancer/cnosdb/CnosDBGlobalState.java new file mode 100644 index 000000000..9f34e03a5 --- /dev/null +++ b/src/sqlancer/cnosdb/CnosDBGlobalState.java @@ -0,0 +1,28 @@ +package sqlancer.cnosdb; + +import sqlancer.ExecutionTimer; +import sqlancer.GlobalState; +import sqlancer.cnosdb.client.CnosDBConnection; +import sqlancer.common.query.Query; + +public class CnosDBGlobalState extends GlobalState { + + @Override + protected void executeEpilogue(Query q, boolean success, ExecutionTimer timer) throws Exception { + boolean logExecutionTime = getOptions().logExecutionTime(); + if (success && getOptions().printSucceedingStatements()) { + System.out.println(q.getQueryString()); + } + if (logExecutionTime) { + getLogger().writeCurrent(" -- " + timer.end().asString()); + } + if (q.couldAffectSchema()) { + updateSchema(); + } + } + + @Override + public CnosDBSchema readSchema() throws Exception { + return CnosDBSchema.fromConnection(getConnection()); + } +} diff --git a/src/sqlancer/cnosdb/CnosDBLoggableFactory.java b/src/sqlancer/cnosdb/CnosDBLoggableFactory.java new file mode 100644 index 000000000..407621c8b --- /dev/null +++ b/src/sqlancer/cnosdb/CnosDBLoggableFactory.java @@ -0,0 +1,55 @@ +package sqlancer.cnosdb; + +import java.io.PrintWriter; +import java.io.StringWriter; + +import sqlancer.cnosdb.query.CnosDBOtherQuery; +import sqlancer.cnosdb.query.CnosDBQueryAdapter; +import sqlancer.common.log.Loggable; +import sqlancer.common.log.LoggableFactory; +import sqlancer.common.log.LoggedString; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.Query; + +public class CnosDBLoggableFactory extends LoggableFactory { + + @Override + protected Loggable createLoggable(String input, String suffix) { + String completeString = input; + if (!input.endsWith(";")) { + completeString += ";"; + } + if (suffix != null && !suffix.isEmpty()) { + completeString += suffix; + } + return new LoggedString(completeString); + } + + @Override + public CnosDBQueryAdapter getQueryForStateToReproduce(String queryString) { + return new CnosDBOtherQuery(queryString, CnosDBExpectedError.expectedErrors()); + } + + @Override + public CnosDBQueryAdapter commentOutQuery(Query query) { + String queryString = query.getLogString(); + String newQueryString = "-- " + queryString; + ExpectedErrors errors = new ExpectedErrors(); + return new CnosDBOtherQuery(newQueryString, errors); + } + + @Override + protected Loggable infoToLoggable(String time, String databaseName, String databaseVersion, long seedValue) { + String sb = "-- Time: " + time + "\n" + "-- Database: " + databaseName + "\n" + "-- Database version: " + + databaseVersion + "\n" + "-- seed value: " + seedValue + "\n"; + return new LoggedString(sb); + } + + @Override + public Loggable convertStacktraceToLoggable(Throwable throwable) { + StringWriter sw = new StringWriter(); + PrintWriter pw = new PrintWriter(sw); + throwable.printStackTrace(pw); + return new LoggedString("--" + sw.toString().replace("\n", "\n--")); + } +} diff --git a/src/sqlancer/cnosdb/CnosDBOptions.java b/src/sqlancer/cnosdb/CnosDBOptions.java new file mode 100644 index 000000000..f101c2d38 --- /dev/null +++ b/src/sqlancer/cnosdb/CnosDBOptions.java @@ -0,0 +1,28 @@ +package sqlancer.cnosdb; + +import java.util.List; + +import com.beust.jcommander.Parameter; +import com.beust.jcommander.Parameters; + +import sqlancer.DBMSSpecificOptions; + +@Parameters(separators = "=", commandDescription = "CnosDB (default port: " + CnosDBOptions.DEFAULT_PORT + + ", default host: " + CnosDBOptions.DEFAULT_HOST + ")") +public class CnosDBOptions implements DBMSSpecificOptions { + + public static final String DEFAULT_HOST = "localhost"; + public static final int DEFAULT_PORT = 31001; + + @Parameter(names = "--oracle", description = "Specifies which test oracle should be used for CnosDB") + public List oracle = List.of(CnosDBOracleFactory.QUERY_PARTITIONING); + + @Parameter(names = "--connection-url", description = "Specifies the URL for connecting to the CnosDB", arity = 1) + public String connectionURL = String.format("http://%s:%d", CnosDBOptions.DEFAULT_HOST, CnosDBOptions.DEFAULT_PORT); + + @Override + public List getTestOracleFactory() { + return oracle; + } + +} diff --git a/src/sqlancer/cnosdb/CnosDBOracleFactory.java b/src/sqlancer/cnosdb/CnosDBOracleFactory.java new file mode 100644 index 000000000..7cb9c4fc6 --- /dev/null +++ b/src/sqlancer/cnosdb/CnosDBOracleFactory.java @@ -0,0 +1,39 @@ +package sqlancer.cnosdb; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.OracleFactory; +import sqlancer.cnosdb.oracle.CnosDBNoRECOracle; +import sqlancer.cnosdb.oracle.tlp.CnosDBTLPAggregateOracle; +import sqlancer.cnosdb.oracle.tlp.CnosDBTLPHavingOracle; +import sqlancer.cnosdb.oracle.tlp.CnosDBTLPWhereOracle; +import sqlancer.common.oracle.CompositeTestOracle; +import sqlancer.common.oracle.TestOracle; + +public enum CnosDBOracleFactory implements OracleFactory { + NOREC { + @Override + public TestOracle create(CnosDBGlobalState globalState) { + return new CnosDBNoRECOracle(globalState); + } + }, + HAVING { + @Override + public TestOracle create(CnosDBGlobalState globalState) { + return new CnosDBTLPHavingOracle(globalState); + } + + }, + QUERY_PARTITIONING { + @Override + public TestOracle create(CnosDBGlobalState globalState) { + List> oracles = new ArrayList<>(); + oracles.add(new CnosDBTLPWhereOracle(globalState)); + oracles.add(new CnosDBTLPHavingOracle(globalState)); + oracles.add(new CnosDBTLPAggregateOracle(globalState)); + return new CompositeTestOracle<>(oracles, globalState); + } + } + +} diff --git a/src/sqlancer/cnosdb/CnosDBProvider.java b/src/sqlancer/cnosdb/CnosDBProvider.java new file mode 100644 index 000000000..8b69c53b3 --- /dev/null +++ b/src/sqlancer/cnosdb/CnosDBProvider.java @@ -0,0 +1,123 @@ +package sqlancer.cnosdb; + +import java.util.Objects; + +import com.google.auto.service.AutoService; + +import sqlancer.AbstractAction; +import sqlancer.DatabaseProvider; +import sqlancer.IgnoreMeException; +import sqlancer.ProviderAdapter; +import sqlancer.Randomly; +import sqlancer.StatementExecutor; +import sqlancer.cnosdb.client.CnosDBClient; +import sqlancer.cnosdb.client.CnosDBConnection; +import sqlancer.cnosdb.gen.CnosDBInsertGenerator; +import sqlancer.cnosdb.gen.CnosDBTableGenerator; +import sqlancer.cnosdb.query.CnosDBOtherQuery; +import sqlancer.cnosdb.query.CnosDBQueryProvider; +import sqlancer.common.log.LoggableFactory; + +@AutoService(DatabaseProvider.class) +public class CnosDBProvider extends ProviderAdapter { + + protected String username; + protected String password; + protected String host; + protected int port; + protected String databaseName; + + public CnosDBProvider() { + super(CnosDBGlobalState.class, CnosDBOptions.class); + } + + protected CnosDBProvider(Class globalClass, Class optionClass) { + super(globalClass, optionClass); + } + + protected static int mapActions(CnosDBGlobalState globalState, Action a) { + Randomly r = globalState.getRandomly(); + int nrPerformed; + if (Objects.requireNonNull(a) == Action.INSERT) { + nrPerformed = r.getInteger(0, globalState.getOptions().getMaxNumberInserts()); + } else { + throw new AssertionError(a); + } + return nrPerformed; + + } + + @Override + protected void checkViewsAreValid(CnosDBGlobalState globalState) { + } + + @Override + public void generateDatabase(CnosDBGlobalState globalState) throws Exception { + createTables(globalState, Randomly.fromOptions(4, 5, 6)); + prepareTables(globalState); + + } + + @Override + public CnosDBConnection createDatabase(CnosDBGlobalState globalState) throws Exception { + + username = globalState.getOptions().getUserName(); + password = globalState.getOptions().getPassword(); + host = globalState.getOptions().getHost(); + port = globalState.getOptions().getPort(); + databaseName = globalState.getDatabaseName(); + CnosDBClient client = new CnosDBClient(host, port, username, password, databaseName); + CnosDBConnection connection = new CnosDBConnection(client); + client.execute("DROP DATABASE IF EXISTS " + databaseName); + globalState.getState().logStatement("DROP DATABASE IF EXISTS " + databaseName); + client.execute("CREATE DATABASE " + databaseName); + globalState.getState().logStatement("CREATE DATABASE " + databaseName); + + return connection; + } + + protected void createTables(CnosDBGlobalState globalState, int numTables) throws Exception { + while (globalState.getSchema().getDatabaseTables().size() < numTables) { + String tableName = String.format("m%d", globalState.getSchema().getDatabaseTables().size()); + CnosDBOtherQuery createTable = CnosDBTableGenerator.generate(tableName); + globalState.executeStatement(createTable); + } + } + + protected void prepareTables(CnosDBGlobalState globalState) throws Exception { + StatementExecutor se = new StatementExecutor<>(globalState, Action.values(), + CnosDBProvider::mapActions, (q) -> { + if (globalState.getSchema().getDatabaseTables().isEmpty()) { + throw new IgnoreMeException(); + } + }); + se.executeStatements(); + } + + @Override + public String getDBMSName() { + return "CnosDB".toLowerCase(); + } + + @Override + public LoggableFactory getLoggableFactory() { + return new CnosDBLoggableFactory(); + } + + public enum Action implements AbstractAction { + INSERT(CnosDBInsertGenerator::insert); + + private final CnosDBQueryProvider sqlQueryProvider; + + Action(CnosDBQueryProvider sqlQueryProvider) { + this.sqlQueryProvider = sqlQueryProvider; + } + + @Override + public CnosDBOtherQuery getQuery(CnosDBGlobalState state) throws Exception { + return new CnosDBOtherQuery(sqlQueryProvider.getQuery(state).getQueryString(), + CnosDBExpectedError.expectedErrors()); + } + } + +} diff --git a/src/sqlancer/cnosdb/CnosDBSchema.java b/src/sqlancer/cnosdb/CnosDBSchema.java new file mode 100644 index 000000000..022969ce5 --- /dev/null +++ b/src/sqlancer/cnosdb/CnosDBSchema.java @@ -0,0 +1,243 @@ +package sqlancer.cnosdb; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import sqlancer.Randomly; +import sqlancer.cnosdb.ast.CnosDBConstant; +import sqlancer.cnosdb.client.CnosDBConnection; +import sqlancer.cnosdb.client.CnosDBResultSet; +import sqlancer.common.schema.AbstractRowValue; +import sqlancer.common.schema.AbstractSchema; +import sqlancer.common.schema.AbstractTable; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.common.schema.AbstractTables; +import sqlancer.common.schema.TableIndex; + +public class CnosDBSchema extends AbstractSchema { + + private final String databaseName; + + public CnosDBSchema(List databaseTables, String databaseName) { + super(databaseTables); + this.databaseName = databaseName; + } + + public static CnosDBDataType getColumnType(String typeString) { + switch (typeString.toLowerCase()) { + case "bigint": + return CnosDBDataType.INT; + case "boolean": + return CnosDBDataType.BOOLEAN; + case "string": + return CnosDBDataType.STRING; + case "double": + return CnosDBDataType.DOUBLE; + case "bigint unsigned": + case "unsigned": + return CnosDBDataType.UINT; + case "timestamp(nanosecond)": + return CnosDBDataType.TIMESTAMP; + default: + throw new AssertionError(typeString); + } + } + + public static CnosDBSchema fromConnection(CnosDBConnection con) throws Exception { + CnosDBResultSet tablesRes = con.getClient().executeQuery("SHOW TABLES"); + + List tables = new ArrayList<>(); + while (tablesRes.next()) { + String tableName = tablesRes.getString(1); + List columns = getTableColumns(con, tableName); + tables.add(new CnosDBTable(tableName, columns)); + } + + return new CnosDBSchema(tables, con.getClient().getDatabase()); + } + + protected static List getTableColumns(CnosDBConnection con, String tableName) throws Exception { + CnosDBResultSet columnsRes = con.getClient().executeQuery("DESCRIBE TABLE " + tableName); + List columns = new ArrayList<>(); + CnosDBTable table = new CnosDBTable(tableName, columns); + while (columnsRes.next()) { + String columnName = columnsRes.getString(1); + String columnType = columnsRes.getString(3).toLowerCase(); + CnosDBDataType dataType = CnosDBSchema.getColumnType(columnsRes.getString(2)); + CnosDBColumn column; + if (columnType.contentEquals("time")) { + column = new CnosDBTimeColumn(); + } else if (columnType.contentEquals("tag")) { + column = new CnosDBTagColumn(columnName); + } else { + column = new CnosDBFieldColumn(columnName, dataType); + } + column.setTable(table); + columns.add(column); + } + + return columns; + } + + public CnosDBTables getRandomTableNonEmptyTables() { + return new CnosDBTables(Randomly.nonEmptySubset(getDatabaseTables())); + } + + public String getDatabaseName() { + return databaseName; + } + + public enum CnosDBDataType { + INT, BOOLEAN, STRING, DOUBLE, UINT, TIMESTAMP; + + public static CnosDBDataType getRandomType() { + return Randomly.fromOptions(values()); + } + + public static CnosDBDataType getRandomTypeWithoutTimeStamp() { + List dataTypes = new ArrayList<>(Arrays.asList(values())); + dataTypes.remove(TIMESTAMP); + return Randomly.fromList(dataTypes); + } + } + + public static class CnosDBColumn extends AbstractTableColumn { + + public CnosDBColumn(String name, CnosDBDataType columnType) { + super(name, null, columnType); + } + + public static CnosDBColumn createDummy(String name) { + return new CnosDBColumn(name, CnosDBDataType.INT); + } + + } + + public static class CnosDBTagColumn extends CnosDBColumn { + public CnosDBTagColumn(String name) { + super(name, CnosDBDataType.STRING); + } + } + + public static class CnosDBTimeColumn extends CnosDBColumn { + public CnosDBTimeColumn() { + super("TIME", CnosDBDataType.TIMESTAMP); + } + } + + public static class CnosDBFieldColumn extends CnosDBColumn { + public CnosDBFieldColumn(String name, CnosDBDataType columnType) { + super(name, columnType); + assert columnType != CnosDBDataType.TIMESTAMP; + } + } + + public static class CnosDBTables extends AbstractTables { + + public CnosDBTables(List tables) { + super(tables); + } + + public CnosDBRowValue getRandomRowValue(CnosDBConnection con) { + return null; + } + + public List getRandomColumnsWithOnlyOneField() { + ArrayList res = new ArrayList<>(); + this.getTables().forEach(table -> res.addAll(table.getRandomColumnsWithOnlyOneField())); + return res; + } + + } + + public static class CnosDBRowValue extends AbstractRowValue { + + protected CnosDBRowValue(CnosDBTables tables, Map values) { + super(tables, values); + } + + } + + public static class CnosDBTable extends AbstractTable { + + public CnosDBTable(String tableName, List columns) { + super(tableName, columns, null, false); + } + + @Override + public List getColumns() { + List res = super.getColumns(); + boolean hasTime = false; + for (CnosDBColumn column : res) { + if (column instanceof CnosDBTimeColumn) { + hasTime = true; + break; + } + } + assert hasTime; + + return res; + } + + public List getRandomColumnsWithOnlyOneField() { + ArrayList res = new ArrayList<>(); + boolean hasField = false; + for (CnosDBColumn column : getColumns()) { + if (column instanceof CnosDBTagColumn && Randomly.getBoolean()) { + res.add(column); + } else if (column instanceof CnosDBFieldColumn && !hasField) { + res.add(column); + hasField = true; + } + } + return res; + } + + // SELECT COUNT(*) FROM table; + @Override + public long getNrRows(CnosDBGlobalState globalState) { + long res; + try { + CnosDBResultSet tableCountRes = globalState.getConnection().getClient() + .executeQuery("SELECT COUNT(time) FROM " + this.name); + tableCountRes.next(); + res = tableCountRes.getLong(1); + } catch (Exception e) { + res = 0; + } + return res; + } + + @Override + public List getRandomNonEmptyColumnSubset() { + List selectedColumns = new ArrayList<>(); + ArrayList remainingColumns = new ArrayList<>(this.getColumns()); + + remainingColumns.removeIf(column -> column instanceof CnosDBTimeColumn); + CnosDBTimeColumn timeColumn = new CnosDBTimeColumn(); + timeColumn.setTable(this); + selectedColumns.add(timeColumn); + + remainingColumns.stream().filter(column -> column instanceof CnosDBTagColumn).findFirst().ifPresent(tag -> { + selectedColumns.add(tag); + remainingColumns.remove(tag); + }); + + remainingColumns.stream().filter(column -> column instanceof CnosDBFieldColumn).findFirst() + .ifPresent(field -> { + selectedColumns.add(field); + remainingColumns.remove(field); + }); + + int nr = Math.min(Randomly.smallNumber() + 1, remainingColumns.size()); + for (int i = 0; i < nr; i++) { + selectedColumns + .add(remainingColumns.remove((int) Randomly.getNotCachedInteger(0, remainingColumns.size()))); + } + return selectedColumns; + } + } + +} diff --git a/src/sqlancer/cnosdb/CnosDBToStringVisitor.java b/src/sqlancer/cnosdb/CnosDBToStringVisitor.java new file mode 100644 index 000000000..388e2ccd8 --- /dev/null +++ b/src/sqlancer/cnosdb/CnosDBToStringVisitor.java @@ -0,0 +1,278 @@ +package sqlancer.cnosdb; + +import sqlancer.Randomly; +import sqlancer.cnosdb.ast.CnosDBAggregate; +import sqlancer.cnosdb.ast.CnosDBBetweenOperation; +import sqlancer.cnosdb.ast.CnosDBBinaryLogicalOperation; +import sqlancer.cnosdb.ast.CnosDBCastOperation; +import sqlancer.cnosdb.ast.CnosDBColumnValue; +import sqlancer.cnosdb.ast.CnosDBConstant; +import sqlancer.cnosdb.ast.CnosDBExpression; +import sqlancer.cnosdb.ast.CnosDBFunction; +import sqlancer.cnosdb.ast.CnosDBInOperation; +import sqlancer.cnosdb.ast.CnosDBJoin; +import sqlancer.cnosdb.ast.CnosDBLikeOperation; +import sqlancer.cnosdb.ast.CnosDBOrderByTerm; +import sqlancer.cnosdb.ast.CnosDBPostfixOperation; +import sqlancer.cnosdb.ast.CnosDBPostfixText; +import sqlancer.cnosdb.ast.CnosDBPrefixOperation; +import sqlancer.cnosdb.ast.CnosDBSelect; +import sqlancer.cnosdb.ast.CnosDBSelect.CnosDBFromTable; +import sqlancer.cnosdb.ast.CnosDBSelect.CnosDBSubquery; +import sqlancer.cnosdb.ast.CnosDBSimilarTo; +import sqlancer.common.visitor.BinaryOperation; +import sqlancer.common.visitor.ToStringVisitor; + +public final class CnosDBToStringVisitor extends ToStringVisitor implements CnosDBVisitor { + + @Override + public void visitSpecific(CnosDBExpression expr) { + CnosDBVisitor.super.visit(expr); + } + + @Override + public void visit(CnosDBConstant constant) { + sb.append(constant.getTextRepresentation()); + } + + @Override + public String get() { + return sb.toString(); + } + + @Override + public void visit(CnosDBPostfixOperation op) { + sb.append("("); + visit(op.getExpression()); + sb.append(")"); + sb.append(" "); + sb.append(op.getOperatorTextRepresentation()); + } + + @Override + public void visit(CnosDBColumnValue c) { + sb.append(c.getColumn().getFullQualifiedName()); + } + + @Override + public void visit(CnosDBPrefixOperation op) { + sb.append(op.getTextRepresentation()); + sb.append(" ("); + visit(op.getExpression()); + sb.append(")"); + } + + @Override + public void visit(CnosDBFromTable from) { + sb.append(from.getTable().getName()); + } + + @Override + public void visit(CnosDBSubquery subquery) { + sb.append("("); + visit(subquery.getSelect()); + sb.append(") AS "); + sb.append(subquery.getName()); + } + + @Override + public void visit(CnosDBSelect s) { + sb.append("SELECT "); + switch (s.getSelectOption()) { + case DISTINCT: + sb.append("DISTINCT "); + if (s.getDistinctOnClause() != null) { + sb.append("ON ("); + visit(s.getDistinctOnClause()); + sb.append(") "); + } + break; + case ALL: + sb.append(Randomly.fromOptions("ALL ", "")); + break; + default: + throw new AssertionError(); + } + if (s.getFetchColumns() == null) { + sb.append("*"); + } else { + visit(s.getFetchColumns()); + } + sb.append(" FROM "); + visit(s.getFromList()); + + for (CnosDBJoin j : s.getJoinClauses()) { + sb.append(" "); + switch (j.getType()) { + case INNER: + if (Randomly.getBoolean()) { + sb.append("INNER "); + } + sb.append("JOIN"); + break; + case LEFT: + sb.append("LEFT OUTER JOIN"); + break; + case RIGHT: + sb.append("RIGHT OUTER JOIN"); + break; + case FULL: + sb.append("FULL OUTER JOIN"); + break; + // case CROSS: + // sb.append("CROSS JOIN"); + // break; + default: + throw new AssertionError(j.getType()); + } + sb.append(" "); + visit(j.getTableReference()); + // if (j.getType() != CnosDBJoinType.CROSS) { + sb.append(" ON "); + visit(j.getOnClause()); + // } + } + + if (s.getWhereClause() != null) { + sb.append(" WHERE "); + visit(s.getWhereClause()); + } + if (!s.getGroupByExpressions().isEmpty()) { + sb.append(" GROUP BY "); + visit(s.getGroupByExpressions()); + } + if (s.getHavingClause() != null) { + sb.append(" HAVING "); + visit(s.getHavingClause()); + + } + if (!s.getOrderByClauses().isEmpty()) { + sb.append(" ORDER BY "); + visit(s.getOrderByClauses()); + } + if (s.getLimitClause() != null) { + sb.append(" LIMIT "); + visit(s.getLimitClause()); + } + + if (s.getOffsetClause() != null) { + sb.append(" OFFSET "); + visit(s.getOffsetClause()); + } + } + + @Override + public void visit(CnosDBOrderByTerm op) { + visit(op.getExpr()); + sb.append(" "); + sb.append(op.getOrder()); + } + + @Override + public void visit(CnosDBFunction f) { + sb.append(f.getFunctionName()); + sb.append("("); + int i = 0; + for (CnosDBExpression arg : f.getArguments()) { + if (i++ != 0) { + sb.append(", "); + } + visit(arg); + } + sb.append(")"); + } + + @Override + public void visit(CnosDBCastOperation cast) { + sb.append("CAST( "); + visit(cast.getExpression()); + sb.append(" AS "); + appendType(cast); + sb.append(")"); + } + + private void appendType(CnosDBCastOperation cast) { + CnosDBCompoundDataType compoundType = cast.getCompoundType(); + switch (compoundType.getDataType()) { + case BOOLEAN: + sb.append("BOOLEAN"); + break; + case INT: + sb.append("BIGINT"); + break; + case STRING: + sb.append(Randomly.fromOptions("STRING")); + break; + case DOUBLE: + sb.append("DOUBLE"); + break; + case UINT: + sb.append("BIGINT UNSIGNED"); + break; + case TIMESTAMP: + sb.append("TIMESTAMP"); + break; + + default: + throw new AssertionError(cast.getType()); + } + } + + @Override + public void visit(CnosDBBetweenOperation op) { + sb.append("("); + visit(op.getExpr()); + sb.append(") BETWEEN ("); + visit(op.getLeft()); + sb.append(") AND ("); + visit(op.getRight()); + sb.append(")"); + } + + @Override + public void visit(CnosDBInOperation op) { + sb.append("("); + visit(op.getExpr()); + sb.append(")"); + if (!op.isTrue()) { + sb.append(" NOT"); + } + sb.append(" IN ("); + visit(op.getListElements()); + sb.append(")"); + } + + @Override + public void visit(CnosDBPostfixText op) { + visit(op.getExpr()); + sb.append(op.getText()); + } + + @Override + public void visit(CnosDBAggregate op) { + sb.append(op.getFunction()); + sb.append("("); + visit(op.getArgs()); + sb.append(")"); + } + + @Override + public void visit(CnosDBSimilarTo op) { + sb.append("("); + visit(op.getString()); + sb.append(" SIMILAR TO "); + visit(op.getSimilarTo()); + sb.append(")"); + } + + @Override + public void visit(CnosDBBinaryLogicalOperation op) { + super.visit((BinaryOperation) op); + } + + @Override + public void visit(CnosDBLikeOperation op) { + super.visit((BinaryOperation) op); + } + +} diff --git a/src/sqlancer/cnosdb/CnosDBVisitor.java b/src/sqlancer/cnosdb/CnosDBVisitor.java new file mode 100644 index 000000000..7c1af7224 --- /dev/null +++ b/src/sqlancer/cnosdb/CnosDBVisitor.java @@ -0,0 +1,102 @@ +package sqlancer.cnosdb; + +import sqlancer.cnosdb.ast.CnosDBAggregate; +import sqlancer.cnosdb.ast.CnosDBBetweenOperation; +import sqlancer.cnosdb.ast.CnosDBBinaryLogicalOperation; +import sqlancer.cnosdb.ast.CnosDBCastOperation; +import sqlancer.cnosdb.ast.CnosDBColumnValue; +import sqlancer.cnosdb.ast.CnosDBConstant; +import sqlancer.cnosdb.ast.CnosDBExpression; +import sqlancer.cnosdb.ast.CnosDBFunction; +import sqlancer.cnosdb.ast.CnosDBInOperation; +import sqlancer.cnosdb.ast.CnosDBLikeOperation; +import sqlancer.cnosdb.ast.CnosDBOrderByTerm; +import sqlancer.cnosdb.ast.CnosDBPostfixOperation; +import sqlancer.cnosdb.ast.CnosDBPostfixText; +import sqlancer.cnosdb.ast.CnosDBPrefixOperation; +import sqlancer.cnosdb.ast.CnosDBSelect; +import sqlancer.cnosdb.ast.CnosDBSelect.CnosDBFromTable; +import sqlancer.cnosdb.ast.CnosDBSelect.CnosDBSubquery; +import sqlancer.cnosdb.ast.CnosDBSimilarTo; + +public interface CnosDBVisitor { + + static String asString(CnosDBExpression expr) { + CnosDBToStringVisitor visitor = new CnosDBToStringVisitor(); + visitor.visit(expr); + return visitor.get(); + } + + void visit(CnosDBConstant constant); + + void visit(CnosDBPostfixOperation op); + + void visit(CnosDBColumnValue c); + + void visit(CnosDBPrefixOperation op); + + void visit(CnosDBSelect op); + + void visit(CnosDBOrderByTerm op); + + void visit(CnosDBFunction f); + + void visit(CnosDBCastOperation cast); + + void visit(CnosDBBetweenOperation op); + + void visit(CnosDBInOperation op); + + void visit(CnosDBPostfixText op); + + void visit(CnosDBAggregate op); + + void visit(CnosDBFromTable from); + + void visit(CnosDBSubquery subquery); + + void visit(CnosDBBinaryLogicalOperation op); + + void visit(CnosDBLikeOperation op); + + void visit(CnosDBSimilarTo op); + + default void visit(CnosDBExpression expression) { + if (expression instanceof CnosDBConstant) { + visit((CnosDBConstant) expression); + } else if (expression instanceof CnosDBPostfixOperation) { + visit((CnosDBPostfixOperation) expression); + } else if (expression instanceof CnosDBColumnValue) { + visit((CnosDBColumnValue) expression); + } else if (expression instanceof CnosDBPrefixOperation) { + visit((CnosDBPrefixOperation) expression); + } else if (expression instanceof CnosDBSelect) { + visit((CnosDBSelect) expression); + } else if (expression instanceof CnosDBOrderByTerm) { + visit((CnosDBOrderByTerm) expression); + } else if (expression instanceof CnosDBFunction) { + visit((CnosDBFunction) expression); + } else if (expression instanceof CnosDBCastOperation) { + visit((CnosDBCastOperation) expression); + } else if (expression instanceof CnosDBBetweenOperation) { + visit((CnosDBBetweenOperation) expression); + } else if (expression instanceof CnosDBInOperation) { + visit((CnosDBInOperation) expression); + } else if (expression instanceof CnosDBAggregate) { + visit((CnosDBAggregate) expression); + } else if (expression instanceof CnosDBPostfixText) { + visit((CnosDBPostfixText) expression); + } else if (expression instanceof CnosDBSimilarTo) { + visit((CnosDBSimilarTo) expression); + } else if (expression instanceof CnosDBFromTable) { + visit((CnosDBFromTable) expression); + } else if (expression instanceof CnosDBSubquery) { + visit((CnosDBSubquery) expression); + } else if (expression instanceof CnosDBLikeOperation) { + visit((CnosDBLikeOperation) expression); + } else { + throw new AssertionError(expression); + } + } + +} diff --git a/src/sqlancer/cnosdb/ast/CnosDBAggregate.java b/src/sqlancer/cnosdb/ast/CnosDBAggregate.java new file mode 100644 index 000000000..df30717b4 --- /dev/null +++ b/src/sqlancer/cnosdb/ast/CnosDBAggregate.java @@ -0,0 +1,113 @@ +package sqlancer.cnosdb.ast; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import sqlancer.Randomly; +import sqlancer.cnosdb.CnosDBBugs; +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; +import sqlancer.cnosdb.ast.CnosDBAggregate.CnosDBAggregateFunction; +import sqlancer.common.ast.FunctionNode; + +public class CnosDBAggregate extends FunctionNode + implements CnosDBExpression { + + public CnosDBAggregate(List args, CnosDBAggregateFunction func) { + super(func, args); + } + + public enum CnosDBAggregateFunction { + AVG(CnosDBDataType.DOUBLE), + MAX(CnosDBDataType.DOUBLE, CnosDBDataType.INT, CnosDBDataType.STRING, CnosDBDataType.TIMESTAMP, + CnosDBDataType.UINT), + MIN(CnosDBDataType.DOUBLE, CnosDBDataType.INT, CnosDBDataType.STRING, CnosDBDataType.TIMESTAMP, + CnosDBDataType.UINT), + COUNT(CnosDBDataType.INT) { + @Override + public CnosDBDataType[] getArgsTypes(CnosDBDataType returnType) { + return new CnosDBDataType[] { CnosDBDataType.getRandomType() }; + } + }, + SUM(CnosDBDataType.INT, CnosDBDataType.DOUBLE, CnosDBDataType.UINT), APPROX_MEDIAN(CnosDBDataType.DOUBLE), + + VAR(CnosDBDataType.DOUBLE), VAR_SAMP(CnosDBDataType.DOUBLE), VAR_POP(CnosDBDataType.DOUBLE), + STDDEV(CnosDBDataType.DOUBLE), STDDEV_SAMP(CnosDBDataType.DOUBLE), STDDEV_POP(CnosDBDataType.DOUBLE), + COVAR(CnosDBDataType.DOUBLE) { + @Override + public CnosDBDataType[] getArgsTypes(CnosDBDataType returnType) { + return new CnosDBDataType[] { CnosDBDataType.DOUBLE, CnosDBDataType.DOUBLE }; + } + }, + COVAR_SAMP(CnosDBDataType.DOUBLE) { + @Override + public CnosDBDataType[] getArgsTypes(CnosDBDataType returnType) { + return new CnosDBDataType[] { CnosDBDataType.DOUBLE, CnosDBDataType.INT }; + } + }, + CORR(CnosDBDataType.DOUBLE) { + @Override + public CnosDBDataType[] getArgsTypes(CnosDBDataType returnType) { + return new CnosDBDataType[] { CnosDBDataType.DOUBLE, CnosDBDataType.DOUBLE }; + } + }, + COVAR_POP(CnosDBDataType.DOUBLE) { + @Override + public CnosDBDataType[] getArgsTypes(CnosDBDataType returnType) { + return new CnosDBDataType[] { CnosDBDataType.DOUBLE, CnosDBDataType.DOUBLE }; + } + }, + + APPROX_PERCENTILE_CONT(CnosDBDataType.DOUBLE) { + @Override + public CnosDBDataType[] getArgsTypes(CnosDBDataType returnType) { + return new CnosDBDataType[] { CnosDBDataType.DOUBLE, CnosDBDataType.DOUBLE }; + } + }, + APPROX_PERCENTILE_CONT_WITH_WEIGHT(CnosDBDataType.DOUBLE) { + @Override + public CnosDBDataType[] getArgsTypes(CnosDBDataType returnType) { + return new CnosDBDataType[] { CnosDBDataType.DOUBLE, CnosDBDataType.DOUBLE, CnosDBDataType.DOUBLE }; + } + }, + APPROX_DISTINCT(CnosDBDataType.UINT), GROUPING(CnosDBDataType.INT), ARRAY_AGG(CnosDBDataType.STRING); + + private final CnosDBDataType[] supportedReturnTypes; + + CnosDBAggregateFunction(CnosDBDataType... supportedReturnTypes) { + this.supportedReturnTypes = supportedReturnTypes.clone(); + } + + public static List getAggregates(CnosDBDataType type) { + List res = Stream.of(values()).filter(p -> p.supportsReturnType(type)) + .collect(Collectors.toList()); + if (CnosDBBugs.BUG786) { + res.removeAll(List.of(VAR, VAR_POP, VAR_SAMP, STDDEV, STDDEV_POP, STDDEV_SAMP, CORR, COVAR, COVAR_POP, + COVAR_SAMP, APPROX_PERCENTILE_CONT_WITH_WEIGHT, APPROX_DISTINCT, APPROX_PERCENTILE_CONT, + APPROX_PERCENTILE_CONT_WITH_WEIGHT, GROUPING, ARRAY_AGG)); + } + + return res; + } + + public CnosDBDataType[] getArgsTypes(CnosDBDataType returnType) { + return new CnosDBDataType[] { returnType }; + } + + public boolean supportsReturnType(CnosDBDataType returnType) { + return Arrays.stream(supportedReturnTypes).anyMatch(t -> t == returnType) + || supportedReturnTypes.length == 0; + } + + public CnosDBDataType getRandomReturnType() { + if (supportedReturnTypes.length == 0) { + return Randomly.fromOptions(CnosDBDataType.getRandomType()); + } else { + return Randomly.fromOptions(supportedReturnTypes); + } + } + + } + +} diff --git a/src/sqlancer/cnosdb/ast/CnosDBAlias.java b/src/sqlancer/cnosdb/ast/CnosDBAlias.java new file mode 100644 index 000000000..86bba199f --- /dev/null +++ b/src/sqlancer/cnosdb/ast/CnosDBAlias.java @@ -0,0 +1,35 @@ +package sqlancer.cnosdb.ast; + +import sqlancer.common.visitor.UnaryOperation; + +public class CnosDBAlias implements UnaryOperation, CnosDBExpression { + + private final CnosDBExpression expr; + private final String alias; + + public CnosDBAlias(CnosDBExpression expr, String alias) { + this.expr = expr; + this.alias = alias; + } + + @Override + public CnosDBExpression getExpression() { + return expr; + } + + @Override + public String getOperatorRepresentation() { + return " as " + alias; + } + + @Override + public OperatorKind getOperatorKind() { + return OperatorKind.POSTFIX; + } + + @Override + public boolean omitBracketsWhenPrinting() { + return true; + } + +} diff --git a/src/sqlancer/cnosdb/ast/CnosDBBetweenOperation.java b/src/sqlancer/cnosdb/ast/CnosDBBetweenOperation.java new file mode 100644 index 000000000..d0addced1 --- /dev/null +++ b/src/sqlancer/cnosdb/ast/CnosDBBetweenOperation.java @@ -0,0 +1,34 @@ +package sqlancer.cnosdb.ast; + +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; + +public final class CnosDBBetweenOperation implements CnosDBExpression { + + private final CnosDBExpression expr; + private final CnosDBExpression left; + private final CnosDBExpression right; + + public CnosDBBetweenOperation(CnosDBExpression expr, CnosDBExpression left, CnosDBExpression right) { + this.expr = expr; + this.left = left; + this.right = right; + } + + public CnosDBExpression getExpr() { + return expr; + } + + public CnosDBExpression getLeft() { + return left; + } + + public CnosDBExpression getRight() { + return right; + } + + @Override + public CnosDBDataType getExpressionType() { + return CnosDBDataType.BOOLEAN; + } + +} diff --git a/src/sqlancer/cnosdb/ast/CnosDBBinaryArithmeticOperation.java b/src/sqlancer/cnosdb/ast/CnosDBBinaryArithmeticOperation.java new file mode 100644 index 000000000..acf3e93d5 --- /dev/null +++ b/src/sqlancer/cnosdb/ast/CnosDBBinaryArithmeticOperation.java @@ -0,0 +1,69 @@ +package sqlancer.cnosdb.ast; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; +import sqlancer.cnosdb.ast.CnosDBBinaryArithmeticOperation.CnosDBBinaryOperator; +import sqlancer.common.ast.BinaryOperatorNode; + +public class CnosDBBinaryArithmeticOperation extends BinaryOperatorNode + implements CnosDBExpression { + + public CnosDBBinaryArithmeticOperation(CnosDBExpression left, CnosDBExpression right, CnosDBBinaryOperator op) { + super(left, right, op); + } + + @Override + public CnosDBDataType getExpressionType() { + return CnosDBDataType.INT; + } + + public enum CnosDBBinaryOperator implements BinaryOperatorNode.Operator { + + ADDITION("+") { + }, + SUBTRACTION("-") { + }, + MULTIPLICATION("*") { + }, + DIVISION("/") { + + }, + MODULO("%") { + }, + EXPONENTIATION("^") { + }; + + private final String textRepresentation; + + CnosDBBinaryOperator(String textRepresentation) { + this.textRepresentation = textRepresentation; + } + + public static CnosDBBinaryOperator getRandom(CnosDBDataType dataType) { + List ops = new ArrayList<>(Arrays.asList(values())); + switch (dataType) { + case DOUBLE: + case UINT: + case STRING: + ops.remove(EXPONENTIATION); + ops.remove(MODULO); + break; + default: + break; + } + + return Randomly.fromList(ops); + } + + @Override + public String getTextRepresentation() { + return textRepresentation; + } + + } + +} diff --git a/src/sqlancer/cnosdb/ast/CnosDBBinaryComparisonOperation.java b/src/sqlancer/cnosdb/ast/CnosDBBinaryComparisonOperation.java new file mode 100644 index 000000000..af38849c9 --- /dev/null +++ b/src/sqlancer/cnosdb/ast/CnosDBBinaryComparisonOperation.java @@ -0,0 +1,57 @@ +package sqlancer.cnosdb.ast; + +import sqlancer.Randomly; +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; +import sqlancer.cnosdb.ast.CnosDBBinaryComparisonOperation.CnosDBBinaryComparisonOperator; +import sqlancer.common.ast.BinaryOperatorNode; + +public class CnosDBBinaryComparisonOperation + extends BinaryOperatorNode implements CnosDBExpression { + + public CnosDBBinaryComparisonOperation(CnosDBExpression left, CnosDBExpression right, + CnosDBBinaryComparisonOperator op) { + super(left, right, op); + } + + @Override + public CnosDBDataType getExpressionType() { + return CnosDBDataType.BOOLEAN; + } + + public enum CnosDBBinaryComparisonOperator implements BinaryOperatorNode.Operator { + EQUALS("=") { + }, + IS_DISTINCT("IS DISTINCT FROM") { + }, + IS_NOT_DISTINCT("IS NOT DISTINCT FROM") { + }, + NOT_EQUALS("!=") { + }, + LESS("<") { + }, + LESS_EQUALS("<=") { + }, + GREATER(">") { + }, + GREATER_EQUALS(">=") { + + }; + + private final String textRepresentation; + + CnosDBBinaryComparisonOperator(String textRepresentation) { + this.textRepresentation = textRepresentation; + } + + public static CnosDBBinaryComparisonOperator getRandom() { + return Randomly.fromOptions(CnosDBBinaryComparisonOperator.values()); + } + + @Override + public String getTextRepresentation() { + return textRepresentation; + } + + } + +} diff --git a/src/sqlancer/cnosdb/ast/CnosDBBinaryLogicalOperation.java b/src/sqlancer/cnosdb/ast/CnosDBBinaryLogicalOperation.java new file mode 100644 index 000000000..bad8a3b75 --- /dev/null +++ b/src/sqlancer/cnosdb/ast/CnosDBBinaryLogicalOperation.java @@ -0,0 +1,33 @@ +package sqlancer.cnosdb.ast; + +import sqlancer.Randomly; +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; +import sqlancer.cnosdb.ast.CnosDBBinaryLogicalOperation.BinaryLogicalOperator; +import sqlancer.common.ast.BinaryOperatorNode; + +public class CnosDBBinaryLogicalOperation extends BinaryOperatorNode + implements CnosDBExpression { + + public CnosDBBinaryLogicalOperation(CnosDBExpression left, CnosDBExpression right, BinaryLogicalOperator op) { + super(left, right, op); + } + + @Override + public CnosDBDataType getExpressionType() { + return CnosDBDataType.BOOLEAN; + } + + public enum BinaryLogicalOperator implements BinaryOperatorNode.Operator { + AND, OR; + + public static BinaryLogicalOperator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return toString(); + } + } + +} diff --git a/src/sqlancer/cnosdb/ast/CnosDBCastOperation.java b/src/sqlancer/cnosdb/ast/CnosDBCastOperation.java new file mode 100644 index 000000000..41db62d81 --- /dev/null +++ b/src/sqlancer/cnosdb/ast/CnosDBCastOperation.java @@ -0,0 +1,60 @@ +package sqlancer.cnosdb.ast; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import sqlancer.cnosdb.CnosDBCompoundDataType; +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; + +public class CnosDBCastOperation implements CnosDBExpression { + + private final CnosDBExpression expression; + private final CnosDBCompoundDataType type; + + public CnosDBCastOperation(CnosDBExpression expression, CnosDBCompoundDataType type) { + if (expression == null) { + throw new AssertionError(); + } + this.expression = expression; + this.type = type; + } + + public static List canCastTo(CnosDBDataType dataType) { + List options = new ArrayList<>(Arrays.asList(CnosDBDataType.values())); + + switch (dataType) { + case UINT: + case BOOLEAN: + case DOUBLE: + options.remove(CnosDBDataType.TIMESTAMP); + break; + case TIMESTAMP: + options.remove(CnosDBDataType.BOOLEAN); + options.remove(CnosDBDataType.UINT); + options.remove(CnosDBDataType.DOUBLE); + break; + default: + break; + } + return options; + } + + @Override + public CnosDBDataType getExpressionType() { + return type.getDataType(); + } + + public CnosDBExpression getExpression() { + return expression; + } + + public CnosDBDataType getType() { + return type.getDataType(); + } + + public CnosDBCompoundDataType getCompoundType() { + return type; + } + +} diff --git a/src/sqlancer/cnosdb/ast/CnosDBColumnValue.java b/src/sqlancer/cnosdb/ast/CnosDBColumnValue.java new file mode 100644 index 000000000..f90b6120f --- /dev/null +++ b/src/sqlancer/cnosdb/ast/CnosDBColumnValue.java @@ -0,0 +1,27 @@ +package sqlancer.cnosdb.ast; + +import sqlancer.cnosdb.CnosDBSchema.CnosDBColumn; +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; + +public class CnosDBColumnValue implements CnosDBExpression { + + private final CnosDBColumn c; + + public CnosDBColumnValue(CnosDBColumn c) { + this.c = c; + } + + public static CnosDBColumnValue create(CnosDBColumn c) { + return new CnosDBColumnValue(c); + } + + @Override + public CnosDBDataType getExpressionType() { + return c.getType(); + } + + public CnosDBColumn getColumn() { + return c; + } + +} diff --git a/src/sqlancer/cnosdb/ast/CnosDBConcatOperation.java b/src/sqlancer/cnosdb/ast/CnosDBConcatOperation.java new file mode 100644 index 000000000..6821f83b8 --- /dev/null +++ b/src/sqlancer/cnosdb/ast/CnosDBConcatOperation.java @@ -0,0 +1,22 @@ +package sqlancer.cnosdb.ast; + +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; +import sqlancer.common.ast.BinaryNode; + +public class CnosDBConcatOperation extends BinaryNode implements CnosDBExpression { + + public CnosDBConcatOperation(CnosDBExpression left, CnosDBExpression right) { + super(left, right); + } + + @Override + public CnosDBDataType getExpressionType() { + return CnosDBDataType.STRING; + } + + @Override + public String getOperatorRepresentation() { + return "||"; + } + +} diff --git a/src/sqlancer/cnosdb/ast/CnosDBConstant.java b/src/sqlancer/cnosdb/ast/CnosDBConstant.java new file mode 100644 index 000000000..42ecd3908 --- /dev/null +++ b/src/sqlancer/cnosdb/ast/CnosDBConstant.java @@ -0,0 +1,520 @@ +package sqlancer.cnosdb.ast; + +import java.math.BigDecimal; +import java.text.SimpleDateFormat; +import java.util.Date; + +import sqlancer.IgnoreMeException; +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; + +public abstract class CnosDBConstant implements CnosDBExpression { + + public static CnosDBConstant createNullConstant() { + return new CnosDBNullConstant(); + } + + public static CnosDBConstant createIntConstant(long val) { + return new IntConstant(val, false); + } + + public static CnosDBConstant createBooleanConstant(boolean val) { + return new BooleanConstant(val); + } + + public static CnosDBConstant createFalse() { + return createBooleanConstant(false); + } + + public static CnosDBConstant createTrue() { + return createBooleanConstant(true); + } + + public static CnosDBConstant createStringConstant(String string) { + return new StringConstant(string); + } + + public static CnosDBConstant createDoubleConstant(double val) { + return new DoubleConstant(val); + } + + public static CnosDBConstant createUintConstant(long val) { + return new IntConstant(val, true); + } + + public static CnosDBConstant createTimeStampConstant(long val) { + return new TimeStampConstant(val); + } + + public abstract String getTextRepresentation(); + + public String asString() { + throw new UnsupportedOperationException(this.toString()); + } + + public boolean isString() { + return false; + } + + public boolean isNull() { + return false; + } + + public boolean asBoolean() { + throw new UnsupportedOperationException(this.toString()); + } + + public long asInt() { + throw new UnsupportedOperationException(this.toString()); + } + + public double asDouble() { + throw new UnsupportedOperationException(this.toString()); + } + + public boolean isBoolean() { + return false; + } + + public abstract CnosDBConstant isEquals(CnosDBConstant rightVal); + + public boolean isInt() { + return false; + } + + protected abstract CnosDBConstant isLessThan(CnosDBConstant rightVal); + + @Override + public String toString() { + return getTextRepresentation(); + } + + public abstract CnosDBConstant cast(CnosDBDataType type); + + public static class BooleanConstant extends CnosDBConstant { + + private final boolean value; + + public BooleanConstant(boolean value) { + this.value = value; + } + + @Override + public String getTextRepresentation() { + return value ? "TRUE" : "FALSE"; + } + + @Override + public CnosDBDataType getExpressionType() { + return CnosDBDataType.BOOLEAN; + } + + @Override + public boolean asBoolean() { + return value; + } + + @Override + public boolean isBoolean() { + return true; + } + + @Override + public CnosDBConstant isEquals(CnosDBConstant rightVal) { + if (rightVal.isNull()) { + return CnosDBConstant.createNullConstant(); + } else if (rightVal.isBoolean()) { + return CnosDBConstant.createBooleanConstant(value == rightVal.asBoolean()); + } else if (rightVal.isString()) { + return CnosDBConstant.createBooleanConstant(value == rightVal.cast(CnosDBDataType.BOOLEAN).asBoolean()); + } else { + throw new AssertionError(rightVal); + } + } + + @Override + protected CnosDBConstant isLessThan(CnosDBConstant rightVal) { + if (rightVal.isNull()) { + return CnosDBConstant.createNullConstant(); + } else if (rightVal.isString()) { + return isLessThan(rightVal.cast(CnosDBDataType.BOOLEAN)); + } else { + assert rightVal.isBoolean(); + return CnosDBConstant.createBooleanConstant((value ? 1 : 0) < (rightVal.asBoolean() ? 1 : 0)); + } + } + + @Override + public CnosDBConstant cast(CnosDBDataType type) { + switch (type) { + case BOOLEAN: + return this; + case INT: + return CnosDBConstant.createIntConstant(value ? 1 : 0); + case UINT: + return CnosDBConstant.createUintConstant(value ? 1 : 0); + case STRING: + return CnosDBConstant.createStringConstant(value ? "true" : "false"); + default: + return null; + } + } + + } + + public static class CnosDBNullConstant extends CnosDBConstant { + + @Override + public String getTextRepresentation() { + return "NULL"; + } + + @Override + public CnosDBDataType getExpressionType() { + return null; + } + + @Override + public boolean isNull() { + return true; + } + + @Override + public CnosDBConstant isEquals(CnosDBConstant rightVal) { + return CnosDBConstant.createNullConstant(); + } + + @Override + protected CnosDBConstant isLessThan(CnosDBConstant rightVal) { + return CnosDBConstant.createNullConstant(); + } + + @Override + public CnosDBConstant cast(CnosDBDataType type) { + return CnosDBConstant.createNullConstant(); + } + } + + public static class StringConstant extends CnosDBConstant { + + private final String value; + + public StringConstant(String value) { + this.value = value; + } + + @Override + public String getTextRepresentation() { + return String.format("'%s'", value.replace("'", "''")); + } + + @Override + public CnosDBConstant isEquals(CnosDBConstant rightVal) { + if (rightVal.isNull()) { + return CnosDBConstant.createNullConstant(); + } else if (rightVal.isInt()) { + return cast(CnosDBDataType.INT).isEquals(rightVal.cast(CnosDBDataType.INT)); + } else if (rightVal.isBoolean()) { + return cast(CnosDBDataType.BOOLEAN).isEquals(rightVal.cast(CnosDBDataType.BOOLEAN)); + } else if (rightVal.isString()) { + return CnosDBConstant.createBooleanConstant(value.contentEquals(rightVal.asString())); + } else { + throw new AssertionError(rightVal); + } + } + + @Override + protected CnosDBConstant isLessThan(CnosDBConstant rightVal) { + if (rightVal.isNull()) { + return CnosDBConstant.createNullConstant(); + } else if (rightVal.isInt()) { + return cast(CnosDBDataType.INT).isLessThan(rightVal.cast(CnosDBDataType.INT)); + } else if (rightVal.isBoolean()) { + return cast(CnosDBDataType.BOOLEAN).isLessThan(rightVal.cast(CnosDBDataType.BOOLEAN)); + } else if (rightVal.isString()) { + return CnosDBConstant.createBooleanConstant(value.compareTo(rightVal.asString()) < 0); + } else { + throw new AssertionError(rightVal); + } + } + + @Override + public CnosDBConstant cast(CnosDBDataType type) { + if (type == CnosDBDataType.STRING) { + return this; + } + String s = value.trim(); + switch (type) { + case BOOLEAN: + try { + return CnosDBConstant.createBooleanConstant(Long.parseLong(s) != 0); + } catch (NumberFormatException ignored) { + } + switch (s.toUpperCase()) { + case "T": + case "TR": + case "TRU": + case "TRUE": + case "1": + case "YES": + case "YE": + case "Y": + case "ON": + return CnosDBConstant.createTrue(); + case "F": + case "FA": + case "FAL": + case "FALS": + case "FALSE": + case "N": + case "NO": + case "OF": + case "OFF": + default: + return CnosDBConstant.createFalse(); + } + case INT: + try { + return CnosDBConstant.createIntConstant(Long.parseLong(s)); + } catch (NumberFormatException e) { + return CnosDBConstant.createIntConstant(-1); + } + case UINT: + try { + return CnosDBConstant.createUintConstant(Long.parseUnsignedLong(s)); + } catch (NumberFormatException e) { + return CnosDBConstant.createUintConstant(0); + } + case DOUBLE: + try { + return CnosDBConstant.createDoubleConstant(Double.parseDouble(s)); + } catch (NumberFormatException e) { + return CnosDBConstant.createDoubleConstant(0.0); + } + + default: + return null; + } + } + + @Override + public CnosDBDataType getExpressionType() { + return CnosDBDataType.STRING; + } + + @Override + public boolean isString() { + return true; + } + + @Override + public String asString() { + return value; + } + + } + + public static class IntConstant extends CnosDBConstant { + + private final long val; + private final boolean unsigned; + + public IntConstant(long val, boolean unsigned) { + this.val = val; + this.unsigned = unsigned; + } + + @Override + public String getTextRepresentation() { + if (unsigned) { + return Long.toUnsignedString(val); + } else { + return String.valueOf(val); + } + } + + @Override + public CnosDBDataType getExpressionType() { + if (unsigned) { + return CnosDBDataType.UINT; + } + return CnosDBDataType.INT; + } + + @Override + public long asInt() { + return val; + } + + @Override + public double asDouble() { + return val; + } + + @Override + public boolean isInt() { + return true; + } + + @Override + public CnosDBConstant isEquals(CnosDBConstant rightVal) { + if (rightVal.isNull()) { + return CnosDBConstant.createNullConstant(); + } else if (rightVal.isBoolean()) { + return cast(CnosDBDataType.BOOLEAN).isEquals(rightVal); + } else if (rightVal.isInt()) { + return CnosDBConstant.createBooleanConstant(val == rightVal.asInt()); + } else if (rightVal.isString()) { + return CnosDBConstant.createBooleanConstant(val == rightVal.cast(CnosDBDataType.INT).asInt()); + } else { + throw new AssertionError(rightVal); + } + } + + @Override + protected CnosDBConstant isLessThan(CnosDBConstant rightVal) { + if (rightVal.isNull()) { + return CnosDBConstant.createNullConstant(); + } else if (rightVal.isInt()) { + return CnosDBConstant.createBooleanConstant(val < rightVal.asInt()); + } else if (rightVal.isBoolean()) { + throw new AssertionError(rightVal); + } else if (rightVal.getExpressionType() == CnosDBDataType.UINT) { + return CnosDBConstant.createBooleanConstant(Long.compareUnsigned(val, rightVal.asInt()) < 0); + } else if (rightVal.isString()) { + return CnosDBConstant.createBooleanConstant(val < rightVal.cast(CnosDBDataType.INT).asInt()); + } else { + throw new IgnoreMeException(); + } + + } + + @Override + public CnosDBConstant cast(CnosDBDataType type) { + switch (type) { + case BOOLEAN: + return CnosDBConstant.createBooleanConstant(val != 0); + case INT: + return CnosDBConstant.createIntConstant(val); + case STRING: + return CnosDBConstant.createStringConstant(String.valueOf(val)); + case UINT: + return CnosDBConstant.createUintConstant(val); + case DOUBLE: + return CnosDBConstant.createDoubleConstant(val); + default: + return null; + } + } + } + + public static class TimeStampConstant extends CnosDBConstant { + final long val; + + TimeStampConstant(long time) { + val = time; + } + + @Override + public String getTextRepresentation() { + return "CAST (" + val + " AS TIMESTAMP)"; + } + + @Override + public CnosDBConstant isEquals(CnosDBConstant rightVal) { + if (rightVal.isNull()) { + return createNullConstant(); + } else if (rightVal.getExpressionType() == CnosDBDataType.TIMESTAMP) { + return createBooleanConstant(val == rightVal.asInt()); + } else { + throw new AssertionError(rightVal); + } + } + + @Override + protected CnosDBConstant isLessThan(CnosDBConstant rightVal) { + if (rightVal.isNull()) { + return CnosDBConstant.createNullConstant(); + } else if (rightVal.getExpressionType() == CnosDBDataType.TIMESTAMP) { + return CnosDBConstant.createBooleanConstant(val < rightVal.asInt()); + } else { + throw new AssertionError(rightVal); + } + } + + @Override + public CnosDBConstant cast(CnosDBDataType type) { + switch (type) { + case INT: + return createIntConstant(val); + case STRING: + final SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd hh:mm:ss"); + return CnosDBConstant.createStringConstant(dateFormat.format(new Date(val))); + default: + return null; + } + } + + @Override + public long asInt() { + return val; + } + + } + + public static class DoubleConstant extends CnosDBConstant { + + private final double val; + + public DoubleConstant(double val) { + this.val = val; + } + + @Override + public String getTextRepresentation() { + if (Double.isFinite(val)) { + BigDecimal bigDecimal = new BigDecimal(val); + return bigDecimal.toPlainString(); + } else { + return String.valueOf(0.0); + } + } + + @Override + public CnosDBDataType getExpressionType() { + return CnosDBDataType.DOUBLE; + } + + @Override + public boolean isNull() { + return false; + } + + @Override + protected CnosDBConstant isLessThan(CnosDBConstant rightVal) { + if (rightVal.isNull()) { + return CnosDBConstant.createNullConstant(); + } else if (rightVal.isBoolean()) { + return cast(CnosDBDataType.BOOLEAN).isLessThan(rightVal); + } else { + return CnosDBConstant.createBooleanConstant(val < rightVal.cast(CnosDBDataType.DOUBLE).asDouble()); + } + } + + @Override + public CnosDBConstant isEquals(CnosDBConstant rightVal) { + if (rightVal.isNull()) { + return CnosDBConstant.createNullConstant(); + } else if (rightVal.isBoolean()) { + return cast(CnosDBDataType.BOOLEAN).isEquals(rightVal); + } else { + return CnosDBConstant.createBooleanConstant(val == rightVal.cast(CnosDBDataType.DOUBLE).asDouble()); + } + } + + @Override + public CnosDBConstant cast(CnosDBDataType type) { + return null; + } + } + +} diff --git a/src/sqlancer/cnosdb/ast/CnosDBExpression.java b/src/sqlancer/cnosdb/ast/CnosDBExpression.java new file mode 100644 index 000000000..63997a0f5 --- /dev/null +++ b/src/sqlancer/cnosdb/ast/CnosDBExpression.java @@ -0,0 +1,14 @@ +package sqlancer.cnosdb.ast; + +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; + +public interface CnosDBExpression { + + default CnosDBDataType getExpressionType() { + return null; + } + + default CnosDBConstant getExpectedValue() { + throw new AssertionError("Not impl"); + } +} diff --git a/src/sqlancer/cnosdb/ast/CnosDBFunction.java b/src/sqlancer/cnosdb/ast/CnosDBFunction.java new file mode 100644 index 000000000..7a35d703e --- /dev/null +++ b/src/sqlancer/cnosdb/ast/CnosDBFunction.java @@ -0,0 +1,30 @@ +package sqlancer.cnosdb.ast; + +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; + +public class CnosDBFunction implements CnosDBExpression { + + private final String func; + private final CnosDBExpression[] args; + private final CnosDBDataType returnType; + + public CnosDBFunction(CnosDBFunctionWithUnknownResult f, CnosDBDataType returnType, CnosDBExpression... args) { + this.func = f.getName(); + this.returnType = returnType; + this.args = args.clone(); + } + + public String getFunctionName() { + return func; + } + + public CnosDBExpression[] getArguments() { + return args.clone(); + } + + @Override + public CnosDBDataType getExpressionType() { + return returnType; + } + +} diff --git a/src/sqlancer/cnosdb/ast/CnosDBFunctionWithUnknownResult.java b/src/sqlancer/cnosdb/ast/CnosDBFunctionWithUnknownResult.java new file mode 100644 index 000000000..485f2309d --- /dev/null +++ b/src/sqlancer/cnosdb/ast/CnosDBFunctionWithUnknownResult.java @@ -0,0 +1,104 @@ +package sqlancer.cnosdb.ast; + +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import sqlancer.cnosdb.CnosDBBugs; +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; +import sqlancer.cnosdb.gen.CnosDBExpressionGenerator; + +public enum CnosDBFunctionWithUnknownResult { + + // String functions + ASCII("ascii", CnosDBDataType.INT, CnosDBDataType.STRING), + BTRIM("btrim", CnosDBDataType.STRING, CnosDBDataType.STRING, CnosDBDataType.STRING), + CHAR_LENGTH("char_length", CnosDBDataType.INT, CnosDBDataType.STRING), + CHARACTER_LENGTH("character_length", CnosDBDataType.INT, CnosDBDataType.STRING), + CONCAT("concat", CnosDBDataType.STRING, CnosDBDataType.STRING, CnosDBDataType.STRING), + CONCAT_WS("concat_ws", CnosDBDataType.STRING, CnosDBDataType.STRING, CnosDBDataType.STRING), + CHR("chr", CnosDBDataType.STRING, CnosDBDataType.INT), + BIT_LENGTH("bit_length", CnosDBDataType.INT, CnosDBDataType.STRING), + INITCAP("initcap", CnosDBDataType.STRING, CnosDBDataType.STRING), + + LEFT("left", CnosDBDataType.STRING, CnosDBDataType.STRING, CnosDBDataType.INT), + LENGTH("length", CnosDBDataType.UINT, CnosDBDataType.STRING), + LOWER("lower", CnosDBDataType.STRING, CnosDBDataType.STRING), + UPPER("upper", CnosDBDataType.STRING, CnosDBDataType.STRING), + LPAD3("lpad", CnosDBDataType.STRING, CnosDBDataType.STRING, CnosDBDataType.INT, CnosDBDataType.STRING), + LPAD2("lpad", CnosDBDataType.STRING, CnosDBDataType.STRING, CnosDBDataType.INT), + RPAD3("rpad", CnosDBDataType.STRING, CnosDBDataType.STRING, CnosDBDataType.INT, CnosDBDataType.STRING), + RPAD2("rpad", CnosDBDataType.STRING, CnosDBDataType.STRING, CnosDBDataType.INT), + LTRIM("ltrim", CnosDBDataType.STRING, CnosDBDataType.STRING, CnosDBDataType.STRING), + OCTET_LENGTH("octet_length", CnosDBDataType.INT, CnosDBDataType.STRING), + // REPEAT("repeat", CnosDBDataType.STRING, CnosDBDataType.STRING, CnosDBDataType.INT), + REPLACE("replace", CnosDBDataType.STRING, CnosDBDataType.STRING, CnosDBDataType.STRING, CnosDBDataType.STRING), + REVERSE("reverse", CnosDBDataType.STRING, CnosDBDataType.STRING), + RIGHT("right", CnosDBDataType.STRING, CnosDBDataType.STRING, CnosDBDataType.INT), + RTRIM("rtrim", CnosDBDataType.STRING, CnosDBDataType.STRING), + SPLIT_PART("split_part", CnosDBDataType.STRING, CnosDBDataType.STRING, CnosDBDataType.STRING, CnosDBDataType.INT), + STARTS_WITH("starts_with", CnosDBDataType.BOOLEAN, CnosDBDataType.STRING, CnosDBDataType.STRING), + STRPOS("strpos", CnosDBDataType.INT, CnosDBDataType.STRING, CnosDBDataType.STRING), + SUBSTR("substr", CnosDBDataType.STRING, CnosDBDataType.STRING, CnosDBDataType.INT, CnosDBDataType.INT), + TRANSLATE("translate", CnosDBDataType.STRING, CnosDBDataType.STRING, CnosDBDataType.STRING, CnosDBDataType.STRING), + MD5("md5", CnosDBDataType.STRING, CnosDBDataType.STRING), + // mathematical functions + ABS("abs", CnosDBDataType.DOUBLE, CnosDBDataType.DOUBLE), + CEIL("ceil", CnosDBDataType.DOUBLE, CnosDBDataType.DOUBLE), + EXP("exp", CnosDBDataType.DOUBLE, CnosDBDataType.DOUBLE), LN("ln", CnosDBDataType.DOUBLE, CnosDBDataType.DOUBLE), + LOG2("log2", CnosDBDataType.DOUBLE, CnosDBDataType.DOUBLE), + LOG10("log10", CnosDBDataType.DOUBLE, CnosDBDataType.DOUBLE), + POWER("power", CnosDBDataType.DOUBLE, CnosDBDataType.DOUBLE, CnosDBDataType.DOUBLE), + ROUND("round", CnosDBDataType.DOUBLE, CnosDBDataType.DOUBLE), + TRUNC("trunc", CnosDBDataType.DOUBLE, CnosDBDataType.INT), + FLOOR("floor", CnosDBDataType.DOUBLE, CnosDBDataType.DOUBLE), + SIGNUM("signum", CnosDBDataType.DOUBLE, CnosDBDataType.DOUBLE), + ACOS("acos", CnosDBDataType.DOUBLE, CnosDBDataType.DOUBLE), + ASIN("asin", CnosDBDataType.DOUBLE, CnosDBDataType.DOUBLE), + ATAN2("atan2", CnosDBDataType.DOUBLE, CnosDBDataType.DOUBLE, CnosDBDataType.DOUBLE), + COS("cos", CnosDBDataType.DOUBLE, CnosDBDataType.DOUBLE), SIN("sin", CnosDBDataType.DOUBLE, CnosDBDataType.DOUBLE), + SQRT("sqrt", CnosDBDataType.DOUBLE, CnosDBDataType.DOUBLE), + TAN("tan", CnosDBDataType.DOUBLE, CnosDBDataType.DOUBLE), + DATE_PART("date_part", CnosDBDataType.INT, CnosDBDataType.STRING, CnosDBDataType.TIMESTAMP), + TO_TIMESTAMP("to_timestamp", CnosDBDataType.TIMESTAMP, CnosDBDataType.INT), + TO_TIMESTAMP_MILLIS("to_timestamp_millis", CnosDBDataType.TIMESTAMP, CnosDBDataType.INT), + TO_TIMESTAMP_MICROS("to_timestamp_micros", CnosDBDataType.TIMESTAMP, CnosDBDataType.INT), + TO_TIMESTAMP_SECONDS("to_timestamp_seconds", CnosDBDataType.TIMESTAMP, CnosDBDataType.INT); + + private final String functionName; + private final CnosDBDataType returnType; + private final CnosDBDataType[] argTypes; + + CnosDBFunctionWithUnknownResult(String functionName, CnosDBDataType returnType, CnosDBDataType... indexType) { + this.functionName = functionName; + this.returnType = returnType; + this.argTypes = indexType.clone(); + + } + + public static List getSupportedFunctions(CnosDBDataType type) { + List res = Stream.of(values()) + .filter(function -> function.isCompatibleWithReturnType(type)).collect(Collectors.toList()); + if (CnosDBBugs.BUG3547) { + res.removeAll(List.of(TO_TIMESTAMP, TO_TIMESTAMP_MICROS, TO_TIMESTAMP_MILLIS, TO_TIMESTAMP_SECONDS)); + } + return res; + } + + public boolean isCompatibleWithReturnType(CnosDBDataType t) { + return t == returnType; + } + + public CnosDBExpression[] getArguments(CnosDBDataType ignore, CnosDBExpressionGenerator gen, int depth) { + CnosDBExpression[] args = new CnosDBExpression[argTypes.length]; + for (int i = 0; i < args.length; i++) { + args[i] = gen.generateExpression(depth, argTypes[i]); + } + return args; + } + + public String getName() { + return functionName; + } + +} diff --git a/src/sqlancer/cnosdb/ast/CnosDBInOperation.java b/src/sqlancer/cnosdb/ast/CnosDBInOperation.java new file mode 100644 index 000000000..c0ffd34ed --- /dev/null +++ b/src/sqlancer/cnosdb/ast/CnosDBInOperation.java @@ -0,0 +1,35 @@ +package sqlancer.cnosdb.ast; + +import java.util.List; + +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; + +public class CnosDBInOperation implements CnosDBExpression { + + private final CnosDBExpression expr; + private final List listElements; + private final boolean isTrue; + + public CnosDBInOperation(CnosDBExpression expr, List listElements, boolean isTrue) { + this.expr = expr; + this.listElements = listElements; + this.isTrue = isTrue; + } + + public CnosDBExpression getExpr() { + return expr; + } + + public List getListElements() { + return listElements; + } + + public boolean isTrue() { + return isTrue; + } + + @Override + public CnosDBDataType getExpressionType() { + return CnosDBDataType.BOOLEAN; + } +} diff --git a/src/sqlancer/cnosdb/ast/CnosDBJoin.java b/src/sqlancer/cnosdb/ast/CnosDBJoin.java new file mode 100644 index 000000000..eea88466f --- /dev/null +++ b/src/sqlancer/cnosdb/ast/CnosDBJoin.java @@ -0,0 +1,46 @@ +package sqlancer.cnosdb.ast; + +import sqlancer.Randomly; +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; + +public class CnosDBJoin implements CnosDBExpression { + + private final CnosDBExpression tableReference; + private final CnosDBExpression onClause; + private final CnosDBJoinType type; + + public CnosDBJoin(CnosDBExpression tableReference, CnosDBExpression onClause, CnosDBJoinType type) { + this.tableReference = tableReference; + this.onClause = onClause; + this.type = type; + } + + public CnosDBExpression getTableReference() { + return tableReference; + } + + public CnosDBExpression getOnClause() { + return onClause; + } + + public CnosDBJoinType getType() { + return type; + } + + @Override + public CnosDBDataType getExpressionType() { + throw new AssertionError(); + } + + public enum CnosDBJoinType { + INNER, LEFT, RIGHT, FULL; + // now not support + // CROSS; + + public static CnosDBJoinType getRandom() { + return Randomly.fromOptions(values()); + } + + } + +} diff --git a/src/sqlancer/cnosdb/ast/CnosDBLikeOperation.java b/src/sqlancer/cnosdb/ast/CnosDBLikeOperation.java new file mode 100644 index 000000000..616cd39ee --- /dev/null +++ b/src/sqlancer/cnosdb/ast/CnosDBLikeOperation.java @@ -0,0 +1,22 @@ +package sqlancer.cnosdb.ast; + +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; +import sqlancer.common.ast.BinaryNode; + +public class CnosDBLikeOperation extends BinaryNode implements CnosDBExpression { + + public CnosDBLikeOperation(CnosDBExpression left, CnosDBExpression right) { + super(left, right); + } + + @Override + public CnosDBDataType getExpressionType() { + return CnosDBDataType.BOOLEAN; + } + + @Override + public String getOperatorRepresentation() { + return "LIKE"; + } + +} diff --git a/src/sqlancer/cnosdb/ast/CnosDBOrderByTerm.java b/src/sqlancer/cnosdb/ast/CnosDBOrderByTerm.java new file mode 100644 index 000000000..de5812d76 --- /dev/null +++ b/src/sqlancer/cnosdb/ast/CnosDBOrderByTerm.java @@ -0,0 +1,37 @@ +package sqlancer.cnosdb.ast; + +import sqlancer.Randomly; +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; + +public class CnosDBOrderByTerm implements CnosDBExpression { + + private final CnosDBOrder order; + private final CnosDBExpression expr; + + public CnosDBOrderByTerm(CnosDBExpression expr, CnosDBOrder order) { + this.expr = expr; + this.order = order; + } + + public CnosDBOrder getOrder() { + return order; + } + + public CnosDBExpression getExpr() { + return expr; + } + + @Override + public CnosDBDataType getExpressionType() { + return null; + } + + public enum CnosDBOrder { + ASC, DESC; + + public static CnosDBOrder getRandomOrder() { + return Randomly.fromOptions(CnosDBOrder.values()); + } + } + +} diff --git a/src/sqlancer/cnosdb/ast/CnosDBPostfixOperation.java b/src/sqlancer/cnosdb/ast/CnosDBPostfixOperation.java new file mode 100644 index 000000000..f37621f44 --- /dev/null +++ b/src/sqlancer/cnosdb/ast/CnosDBPostfixOperation.java @@ -0,0 +1,97 @@ +package sqlancer.cnosdb.ast; + +import sqlancer.Randomly; +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; +import sqlancer.common.ast.BinaryOperatorNode.Operator; + +public class CnosDBPostfixOperation implements CnosDBExpression { + + private final CnosDBExpression expr; + private final String operatorTextRepresentation; + + public CnosDBPostfixOperation(CnosDBExpression expr, PostfixOperator op) { + this.expr = expr; + this.operatorTextRepresentation = Randomly.fromOptions(op.textRepresentations); + } + + public static CnosDBExpression create(CnosDBExpression expr, PostfixOperator op) { + return new CnosDBPostfixOperation(expr, op); + } + + @Override + public CnosDBDataType getExpressionType() { + return CnosDBDataType.BOOLEAN; + } + + public String getOperatorTextRepresentation() { + return operatorTextRepresentation; + } + + public CnosDBExpression getExpression() { + return expr; + } + + public enum PostfixOperator implements Operator { + IS_NULL("IS NULL"/* , "ISNULL" */) { + @Override + public CnosDBDataType[] getInputDataTypes() { + return CnosDBDataType.values(); + } + + }, + IS_UNKNOWN("IS UNKNOWN") { + @Override + public CnosDBDataType[] getInputDataTypes() { + return new CnosDBDataType[] { CnosDBDataType.BOOLEAN }; + } + }, + + IS_NOT_NULL("IS NOT NULL"/* "NOTNULL" */) { + + @Override + public CnosDBDataType[] getInputDataTypes() { + return CnosDBDataType.values(); + } + + }, + IS_NOT_UNKNOWN("IS NOT UNKNOWN") { + + @Override + public CnosDBDataType[] getInputDataTypes() { + return new CnosDBDataType[] { CnosDBDataType.BOOLEAN }; + } + }, + IS_TRUE("IS TRUE") { + @Override + public CnosDBDataType[] getInputDataTypes() { + return new CnosDBDataType[] { CnosDBDataType.BOOLEAN }; + } + + }, + IS_FALSE("IS FALSE") { + @Override + public CnosDBDataType[] getInputDataTypes() { + return new CnosDBDataType[] { CnosDBDataType.BOOLEAN }; + } + + }; + + private final String[] textRepresentations; + + PostfixOperator(String... textRepresentations) { + this.textRepresentations = textRepresentations.clone(); + } + + public static PostfixOperator getRandom() { + return Randomly.fromOptions(values()); + } + + public abstract CnosDBDataType[] getInputDataTypes(); + + @Override + public String getTextRepresentation() { + return toString(); + } + } + +} diff --git a/src/sqlancer/cnosdb/ast/CnosDBPostfixText.java b/src/sqlancer/cnosdb/ast/CnosDBPostfixText.java new file mode 100644 index 000000000..241fab89a --- /dev/null +++ b/src/sqlancer/cnosdb/ast/CnosDBPostfixText.java @@ -0,0 +1,29 @@ +package sqlancer.cnosdb.ast; + +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; + +public class CnosDBPostfixText implements CnosDBExpression { + + private final CnosDBExpression expr; + private final String text; + private final CnosDBDataType type; + + public CnosDBPostfixText(CnosDBExpression expr, String text, CnosDBDataType type) { + this.expr = expr; + this.text = text; + this.type = type; + } + + public CnosDBExpression getExpr() { + return expr; + } + + public String getText() { + return text; + } + + @Override + public CnosDBDataType getExpressionType() { + return type; + } +} diff --git a/src/sqlancer/cnosdb/ast/CnosDBPrefixOperation.java b/src/sqlancer/cnosdb/ast/CnosDBPrefixOperation.java new file mode 100644 index 000000000..db37f0089 --- /dev/null +++ b/src/sqlancer/cnosdb/ast/CnosDBPrefixOperation.java @@ -0,0 +1,73 @@ +package sqlancer.cnosdb.ast; + +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; +import sqlancer.common.ast.BinaryOperatorNode.Operator; + +public class CnosDBPrefixOperation implements CnosDBExpression { + + private final CnosDBExpression expr; + private final PrefixOperator op; + + public CnosDBPrefixOperation(CnosDBExpression expr, PrefixOperator op) { + this.expr = expr; + this.op = op; + } + + @Override + public CnosDBDataType getExpressionType() { + return op.getExpressionType(); + } + + public CnosDBDataType[] getInputDataTypes() { + return op.dataTypes; + } + + public String getTextRepresentation() { + return op.textRepresentation; + } + + public CnosDBExpression getExpression() { + return expr; + } + + public enum PrefixOperator implements Operator { + NOT("NOT", CnosDBDataType.BOOLEAN) { + @Override + public CnosDBDataType getExpressionType() { + return CnosDBDataType.BOOLEAN; + } + + }, + UNARY_PLUS("+", CnosDBDataType.INT) { + @Override + public CnosDBDataType getExpressionType() { + return CnosDBDataType.INT; + } + + }, + UNARY_MINUS("-", CnosDBDataType.INT) { + @Override + public CnosDBDataType getExpressionType() { + return CnosDBDataType.INT; + } + + }; + + private final String textRepresentation; + private final CnosDBDataType[] dataTypes; + + PrefixOperator(String textRepresentation, CnosDBDataType... dataTypes) { + this.textRepresentation = textRepresentation; + this.dataTypes = dataTypes.clone(); + } + + public abstract CnosDBDataType getExpressionType(); + + @Override + public String getTextRepresentation() { + return toString(); + } + + } + +} diff --git a/src/sqlancer/cnosdb/ast/CnosDBSelect.java b/src/sqlancer/cnosdb/ast/CnosDBSelect.java new file mode 100644 index 000000000..0db657f19 --- /dev/null +++ b/src/sqlancer/cnosdb/ast/CnosDBSelect.java @@ -0,0 +1,102 @@ +package sqlancer.cnosdb.ast; + +import java.util.Collections; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; +import sqlancer.cnosdb.CnosDBSchema.CnosDBTable; +import sqlancer.common.ast.SelectBase; + +public class CnosDBSelect extends SelectBase implements CnosDBExpression { + + private SelectType selectOption = SelectType.ALL; + private List joinClauses = Collections.emptyList(); + private CnosDBExpression distinctOnClause; + + public void setSelectType(SelectType fromOptions) { + this.setSelectOption(fromOptions); + } + + public SelectType getSelectOption() { + return selectOption; + } + + public void setSelectOption(SelectType fromOptions) { + this.selectOption = fromOptions; + } + + @Override + public CnosDBDataType getExpressionType() { + return null; + } + + public List getJoinClauses() { + return joinClauses; + } + + public void setJoinClauses(List joinStatements) { + this.joinClauses = joinStatements; + + } + + public CnosDBExpression getDistinctOnClause() { + return distinctOnClause; + } + + public void setDistinctOnClause(CnosDBExpression distinctOnClause) { + if (selectOption != SelectType.DISTINCT) { + throw new IllegalArgumentException(); + } + this.distinctOnClause = distinctOnClause; + } + + public enum SelectType { + DISTINCT, ALL; + + public static SelectType getRandom() { + return Randomly.fromOptions(values()); + } + } + + public static class CnosDBFromTable implements CnosDBExpression { + private final CnosDBTable t; + + public CnosDBFromTable(CnosDBTable t) { + this.t = t; + } + + public CnosDBTable getTable() { + return t; + } + + @Override + public CnosDBDataType getExpressionType() { + return null; + } + } + + public static class CnosDBSubquery implements CnosDBExpression { + private final CnosDBSelect s; + private final String name; + + public CnosDBSubquery(CnosDBSelect s, String name) { + this.s = s; + this.name = name; + } + + public CnosDBSelect getSelect() { + return s; + } + + public String getName() { + return name; + } + + @Override + public CnosDBDataType getExpressionType() { + return null; + } + } + +} diff --git a/src/sqlancer/cnosdb/ast/CnosDBSimilarTo.java b/src/sqlancer/cnosdb/ast/CnosDBSimilarTo.java new file mode 100644 index 000000000..9e3467ada --- /dev/null +++ b/src/sqlancer/cnosdb/ast/CnosDBSimilarTo.java @@ -0,0 +1,28 @@ +package sqlancer.cnosdb.ast; + +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; + +public class CnosDBSimilarTo implements CnosDBExpression { + + private final CnosDBExpression string; + private final CnosDBExpression similarTo; + + public CnosDBSimilarTo(CnosDBExpression string, CnosDBExpression similarTo) { + this.string = string; + this.similarTo = similarTo; + } + + public CnosDBExpression getString() { + return string; + } + + public CnosDBExpression getSimilarTo() { + return similarTo; + } + + @Override + public CnosDBDataType getExpressionType() { + return CnosDBDataType.BOOLEAN; + } + +} diff --git a/src/sqlancer/cnosdb/client/CnosDBClient.java b/src/sqlancer/cnosdb/client/CnosDBClient.java new file mode 100644 index 000000000..ccc9dcc16 --- /dev/null +++ b/src/sqlancer/cnosdb/client/CnosDBClient.java @@ -0,0 +1,110 @@ +package sqlancer.cnosdb.client; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.UnsupportedEncodingException; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; + +import org.apache.commons.codec.binary.Base64; +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpUriRequest; +import org.apache.http.client.utils.URIBuilder; +import org.apache.http.entity.StringEntity; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.HttpClientBuilder; + +import com.arangodb.internal.util.IOUtils; + +public class CnosDBClient { + private final String userName; + private final String password; + private final String host; + private final int port; + + private final String database; + private final CloseableHttpClient client; + + public CnosDBClient(String host, int port, String userName, String password, String database) { + this.host = host; + this.port = port; + this.userName = userName; + this.password = password; + this.database = database; + this.client = HttpClientBuilder.create().build(); + } + + private String url() { + return "http://" + host + ":" + port + "/api/v1/"; + } + + public String ping() throws Exception { + HttpGet httpGet = new HttpGet(this.url() + "ping"); + httpGet.setHeader(HttpHeaders.AUTHORIZATION, getAuth()); + CloseableHttpResponse resp = client.execute(httpGet); + + String content = IOUtils.toString(resp.getEntity().getContent()); + resp.close(); + return content; + } + + public CnosDBResultSet executeQuery(String query) throws Exception { + HttpUriRequest request = createRequest(query); + CloseableHttpResponse resp = client.execute(request); + String text = IOUtils.toString(resp.getEntity().getContent()); + if (resp.getStatusLine().getStatusCode() != 200) { + resp.close(); + throw new CnosDBException(database + ":" + query + ";\n" + text); + } + resp.close(); + InputStream stream = new ByteArrayInputStream(text.getBytes(StandardCharsets.UTF_8)); + + return new CnosDBResultSet(new InputStreamReader(stream)); + } + + public boolean execute(String query) throws Exception { + HttpUriRequest request = createRequest(query); + CloseableHttpResponse resp = client.execute(request); + if (resp.getStatusLine().getStatusCode() != 200) { + String res = IOUtils.toString(resp.getEntity().getContent()); + resp.close(); + throw new CnosDBException(query + res); + } + resp.close(); + return true; + } + + public void close() throws IOException { + client.close(); + } + + public String getDatabase() { + return this.database; + } + + private String getAuth() { + String auth = userName + ":" + password; + byte[] encodedAuth = Base64.encodeBase64(auth.getBytes(StandardCharsets.ISO_8859_1)); + return "Basic " + new String(encodedAuth); + + } + + private HttpUriRequest createRequest(String query) throws URISyntaxException, UnsupportedEncodingException { + + URIBuilder builder = new URIBuilder(this.url() + "sql"); + builder.setParameter("db", database); + builder.setParameter("pretty", "true"); + HttpPost httpPost = new HttpPost(builder.build()); + + httpPost.setHeader(HttpHeaders.AUTHORIZATION, getAuth()); + StringEntity stringEntity = new StringEntity(query); + httpPost.setEntity(stringEntity); + return httpPost; + } + +} diff --git a/src/sqlancer/cnosdb/client/CnosDBConnection.java b/src/sqlancer/cnosdb/client/CnosDBConnection.java new file mode 100644 index 000000000..9277f203b --- /dev/null +++ b/src/sqlancer/cnosdb/client/CnosDBConnection.java @@ -0,0 +1,27 @@ +package sqlancer.cnosdb.client; + +import java.io.IOException; + +import sqlancer.SQLancerDBConnection; + +public class CnosDBConnection implements SQLancerDBConnection { + private final CnosDBClient client; + + public CnosDBConnection(CnosDBClient client) { + this.client = client; + } + + @Override + public String getDatabaseVersion() throws Exception { + return client.ping(); + } + + public CnosDBClient getClient() { + return client; + } + + @Override + public void close() throws IOException { + client.close(); + } +} diff --git a/src/sqlancer/cnosdb/client/CnosDBException.java b/src/sqlancer/cnosdb/client/CnosDBException.java new file mode 100644 index 000000000..a1055e90b --- /dev/null +++ b/src/sqlancer/cnosdb/client/CnosDBException.java @@ -0,0 +1,9 @@ +package sqlancer.cnosdb.client; + +public class CnosDBException extends RuntimeException { + private static final long serialVersionUID = 1L; + + CnosDBException(String message) { + super(message); + } +} diff --git a/src/sqlancer/cnosdb/client/CnosDBResultSet.java b/src/sqlancer/cnosdb/client/CnosDBResultSet.java new file mode 100644 index 000000000..877b6ba5d --- /dev/null +++ b/src/sqlancer/cnosdb/client/CnosDBResultSet.java @@ -0,0 +1,52 @@ +package sqlancer.cnosdb.client; + +import java.io.Reader; +import java.sql.SQLException; +import java.util.Iterator; + +import org.apache.commons.csv.CSVFormat; +import org.apache.commons.csv.CSVRecord; + +import sqlancer.IgnoreMeException; + +public class CnosDBResultSet { + private final Iterator records; + private CSVRecord next; + + public CnosDBResultSet(Reader in) throws Exception { + Iterable records = CSVFormat.DEFAULT.builder().setHeader().setSkipHeaderRecord(true).build() + .parse(in); + this.records = records.iterator(); + } + + public void close() { + } + + public boolean next() throws SQLException { + if (records.hasNext()) { + next = records.next(); + return true; + } + return false; + } + + public int getInt(int i) throws SQLException { + return Integer.parseInt(next.get(i - 1)); + } + + public String getString(int i) throws SQLException { + return next.get(i - 1); + } + + public long getLong(int i) throws SQLException { + if (next.get(i - 1).isEmpty()) { + throw new IgnoreMeException(); + } + return Long.parseLong(next.get(i - 1)); + } + + // public boolean getBool(int i) throws Exception { + // return Boolean.parseBoolean(getString(i)); + // } + +} diff --git a/src/sqlancer/cnosdb/gen/CnosDBCommon.java b/src/sqlancer/cnosdb/gen/CnosDBCommon.java new file mode 100644 index 000000000..6c7b0bba7 --- /dev/null +++ b/src/sqlancer/cnosdb/gen/CnosDBCommon.java @@ -0,0 +1,31 @@ +package sqlancer.cnosdb.gen; + +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; + +public final class CnosDBCommon { + + private CnosDBCommon() { + } + + public static void appendDataType(CnosDBDataType type, StringBuilder sb) throws AssertionError { + switch (type) { + case BOOLEAN: + sb.append("BOOLEAN"); + break; + case INT: + sb.append("BIGINT"); + break; + case STRING: + sb.append("STRING"); + break; + case DOUBLE: + sb.append("DOUBLE"); + break; + case UINT: + sb.append("BIGINT UNSIGNED"); + break; + default: + throw new AssertionError(type); + } + } +} diff --git a/src/sqlancer/cnosdb/gen/CnosDBExpressionGenerator.java b/src/sqlancer/cnosdb/gen/CnosDBExpressionGenerator.java new file mode 100644 index 000000000..121f78254 --- /dev/null +++ b/src/sqlancer/cnosdb/gen/CnosDBExpressionGenerator.java @@ -0,0 +1,461 @@ +package sqlancer.cnosdb.gen; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.cnosdb.CnosDBCompoundDataType; +import sqlancer.cnosdb.CnosDBGlobalState; +import sqlancer.cnosdb.CnosDBSchema.CnosDBColumn; +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; +import sqlancer.cnosdb.ast.CnosDBAggregate; +import sqlancer.cnosdb.ast.CnosDBAggregate.CnosDBAggregateFunction; +import sqlancer.cnosdb.ast.CnosDBBetweenOperation; +import sqlancer.cnosdb.ast.CnosDBBinaryArithmeticOperation; +import sqlancer.cnosdb.ast.CnosDBBinaryArithmeticOperation.CnosDBBinaryOperator; +import sqlancer.cnosdb.ast.CnosDBBinaryComparisonOperation; +import sqlancer.cnosdb.ast.CnosDBBinaryLogicalOperation; +import sqlancer.cnosdb.ast.CnosDBBinaryLogicalOperation.BinaryLogicalOperator; +import sqlancer.cnosdb.ast.CnosDBCastOperation; +import sqlancer.cnosdb.ast.CnosDBColumnValue; +import sqlancer.cnosdb.ast.CnosDBConcatOperation; +import sqlancer.cnosdb.ast.CnosDBConstant; +import sqlancer.cnosdb.ast.CnosDBExpression; +import sqlancer.cnosdb.ast.CnosDBFunction; +import sqlancer.cnosdb.ast.CnosDBFunctionWithUnknownResult; +import sqlancer.cnosdb.ast.CnosDBInOperation; +import sqlancer.cnosdb.ast.CnosDBLikeOperation; +import sqlancer.cnosdb.ast.CnosDBOrderByTerm; +import sqlancer.cnosdb.ast.CnosDBOrderByTerm.CnosDBOrder; +import sqlancer.cnosdb.ast.CnosDBPostfixOperation; +import sqlancer.cnosdb.ast.CnosDBPostfixOperation.PostfixOperator; +import sqlancer.cnosdb.ast.CnosDBPrefixOperation; +import sqlancer.cnosdb.ast.CnosDBPrefixOperation.PrefixOperator; +import sqlancer.cnosdb.ast.CnosDBSimilarTo; +import sqlancer.common.gen.ExpressionGenerator; + +public class CnosDBExpressionGenerator implements ExpressionGenerator { + + private final int maxDepth; + + private final Randomly r; + + private List columns; + + private boolean allowAggregateFunctions; + + public CnosDBExpressionGenerator(CnosDBGlobalState globalState) { + this.r = globalState.getRandomly(); + this.maxDepth = globalState.getOptions().getMaxExpressionDepth(); + } + + public static CnosDBExpression generateExpression(CnosDBGlobalState globalState, CnosDBDataType type) { + return new CnosDBExpressionGenerator(globalState).generateExpression(0, type); + } + + private static CnosDBCompoundDataType getCompoundDataType(CnosDBDataType type) { + return CnosDBCompoundDataType.create(type); + } + + public static CnosDBExpression generateConstant(Randomly r, CnosDBDataType type) { + if (Randomly.getBooleanWithSmallProbability()) { + return CnosDBConstant.createNullConstant(); + } + switch (type) { + case INT: + return CnosDBConstant.createIntConstant(r.getInteger()); + case UINT: + return CnosDBConstant.createUintConstant(r.getPositiveInteger()); + case TIMESTAMP: + return CnosDBConstant.createTimeStampConstant(r.getPositiveIntegerNotNull()); + case BOOLEAN: + return CnosDBConstant.createBooleanConstant(Randomly.getBoolean()); + case STRING: + return CnosDBConstant.createStringConstant(r.getString()); + case DOUBLE: + return CnosDBConstant.createDoubleConstant(r.getDouble()); + default: + throw new AssertionError(type); + } + } + + public static CnosDBExpression generateExpression(CnosDBGlobalState globalState, List columns, + CnosDBDataType type) { + return new CnosDBExpressionGenerator(globalState).setColumns(columns).generateExpression(0, type); + } + + public static CnosDBExpression generateExpression(CnosDBGlobalState globalState, List columns) { + return new CnosDBExpressionGenerator(globalState).setColumns(columns).generateExpression(0); + } + + public CnosDBExpressionGenerator setColumns(List columns) { + this.columns = columns; + return this; + } + + public CnosDBExpression generateExpression(int depth) { + return generateExpression(depth, CnosDBDataType.getRandomType()); + } + + public List generateOrderBy() { + List orderBys = new ArrayList<>(); + for (int i = 0; i < Randomly.smallNumber(); i++) { + orderBys.add(new CnosDBOrderByTerm(CnosDBColumnValue.create(Randomly.fromList(columns)), + CnosDBOrder.getRandomOrder())); + } + return orderBys; + } + + private CnosDBExpression generateFunctionWithUnknownResult(int depth, CnosDBDataType type) { + List supportedFunctions = CnosDBFunctionWithUnknownResult + .getSupportedFunctions(type); + if (supportedFunctions.isEmpty()) { + throw new IgnoreMeException(); + } + CnosDBFunctionWithUnknownResult randomFunction = Randomly.fromList(supportedFunctions); + return new CnosDBFunction(randomFunction, type, randomFunction.getArguments(type, this, depth + 1)); + } + + private CnosDBExpression generateBooleanExpression(int depth) { + List validOptions = new ArrayList<>(Arrays.asList(BooleanExpression.values())); + BooleanExpression option = Randomly.fromList(validOptions); + switch (option) { + case POSTFIX_OPERATOR: + PostfixOperator random = PostfixOperator.getRandom(); + return CnosDBPostfixOperation + .create(generateExpression(depth + 1, Randomly.fromOptions(random.getInputDataTypes())), random); + case IN_OPERATION: + return inOperation(depth + 1); + case NOT: + return new CnosDBPrefixOperation(generateExpression(depth + 1, CnosDBDataType.BOOLEAN), PrefixOperator.NOT); + case BINARY_LOGICAL_OPERATOR: + CnosDBExpression first = generateExpression(depth + 1, CnosDBDataType.BOOLEAN); + int nr = Randomly.smallNumber() + 1; + for (int i = 0; i < nr; i++) { + first = new CnosDBBinaryLogicalOperation(first, generateExpression(depth + 1, CnosDBDataType.BOOLEAN), + BinaryLogicalOperator.getRandom()); + } + return first; + case BINARY_COMPARISON: + CnosDBDataType dataType = getMeaningfulType(); + return generateComparison(depth, dataType); + case CAST: + return generateCastExpression(depth + 1, CnosDBDataType.BOOLEAN); + case FUNCTION: + return generateFunction(depth + 1, CnosDBDataType.BOOLEAN); + case LIKE: + return new CnosDBLikeOperation(generateExpression(depth + 1, CnosDBDataType.STRING), + generateExpression(depth + 1, CnosDBDataType.STRING)); + case BETWEEN: + CnosDBDataType type = getMeaningfulType(); + return new CnosDBBetweenOperation(generateExpression(depth + 1, type), generateExpression(depth + 1, type), + generateExpression(depth + 1, type)); + case SIMILAR_TO: + return new CnosDBSimilarTo(generateExpression(depth + 1, CnosDBDataType.STRING), + generateExpression(depth + 1, CnosDBDataType.STRING)); + default: + throw new AssertionError(); + } + } + + private CnosDBDataType getMeaningfulType() { + // make it more likely that the expression does not only consist of constant + // expressions + if (Randomly.getBooleanWithSmallProbability() || columns == null || columns.isEmpty()) { + return CnosDBDataType.getRandomType(); + } else { + return Randomly.fromList(columns).getType(); + } + } + + private CnosDBExpression generateFunction(int depth, CnosDBDataType type) { + return generateFunctionWithUnknownResult(depth, type); + } + + private CnosDBExpression generateComparison(int depth, CnosDBDataType dataType) { + CnosDBExpression leftExpr = generateExpression(depth + 1, dataType); + CnosDBExpression rightExpr = generateExpression(depth + 1, dataType); + return getComparison(leftExpr, rightExpr); + } + + private CnosDBExpression getComparison(CnosDBExpression leftExpr, CnosDBExpression rightExpr) { + return new CnosDBBinaryComparisonOperation(leftExpr, rightExpr, + CnosDBBinaryComparisonOperation.CnosDBBinaryComparisonOperator.getRandom()); + } + + private CnosDBExpression inOperation(int depth) { + CnosDBDataType type = CnosDBDataType.getRandomType(); + CnosDBExpression leftExpr = generateExpression(depth + 1, type); + List rightExpr = new ArrayList<>(); + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + rightExpr.add(generateConstant(new Randomly(), type)); + } + return new CnosDBInOperation(leftExpr, rightExpr, Randomly.getBoolean()); + } + + public CnosDBExpression generateExpression(int depth, CnosDBDataType originalType) { + return generateExpressionInternal(depth, originalType); + } + + private CnosDBExpression generateExpressionInternal(int depth, CnosDBDataType dataType) throws AssertionError { + if (allowAggregateFunctions && Randomly.getBoolean()) { + return getAggregate(dataType); + } + + if (Randomly.getBooleanWithRatherLowProbability() || depth > maxDepth) { + // generic expression + if (Randomly.getBoolean() || depth > maxDepth) { + if (Randomly.getBooleanWithRatherLowProbability()) { + return generateConstant(r, dataType); + } else { + if (filterColumns(dataType).isEmpty()) { + return generateConstant(r, dataType); + } else { + return createColumnOfType(dataType); + } + } + } else { + if (Randomly.getBoolean()) { + return generateCastExpression(depth + 1, dataType); + } else { + return generateFunctionWithUnknownResult(depth, dataType); + } + } + } else { + switch (dataType) { + case BOOLEAN: + return generateBooleanExpression(depth); + case INT: + return generateIntExpression(depth); + case UINT: + return generateUIntExpression(depth); + case STRING: + return generateStringExpression(depth); + case DOUBLE: + return generateFloatExpression(depth); + case TIMESTAMP: + return generateTimeStampExpression(depth); + default: + throw new AssertionError(dataType); + } + } + } + + private CnosDBExpression generateStringExpression(int depth) { + StringExpression option; + List validOptions = new ArrayList<>(Arrays.asList(StringExpression.values())); + option = Randomly.fromList(validOptions); + + switch (option) { + case CAST: + return generateCastExpression(depth + 1, CnosDBDataType.STRING); + case FUNCTION: + return generateFunction(depth + 1, CnosDBDataType.STRING); + case CONCAT: + return generateConcat(depth); + default: + throw new AssertionError(); + } + } + + private CnosDBExpression generateConcat(int depth) { + CnosDBExpression left = generateExpression(depth + 1, CnosDBDataType.STRING); + CnosDBExpression right = generateExpression(depth + 1); + return new CnosDBConcatOperation(left, right); + } + + private CnosDBExpression generateIntExpression(int depth) { + IntExpression option; + option = Randomly.fromOptions(IntExpression.values()); + switch (option) { + case CAST: + return generateCastExpression(depth + 1, CnosDBDataType.INT); + case UNARY_OPERATION: + CnosDBExpression intExpression = generateExpression(depth + 1, CnosDBDataType.INT); + return new CnosDBPrefixOperation(intExpression, + Randomly.getBoolean() ? PrefixOperator.UNARY_PLUS : PrefixOperator.UNARY_MINUS); + case FUNCTION: + return generateFunction(depth + 1, CnosDBDataType.INT); + case BINARY_ARITHMETIC_EXPRESSION: + return new CnosDBBinaryArithmeticOperation(generateExpression(depth + 1, CnosDBDataType.INT), + generateExpression(depth + 1, CnosDBDataType.INT), + CnosDBBinaryOperator.getRandom(CnosDBDataType.INT)); + default: + throw new AssertionError(); + } + } + + private CnosDBExpression generateUIntExpression(int depth) { + UIntExpression option = Randomly.fromOptions(UIntExpression.values()); + switch (option) { + case CAST: + return generateCastExpression(depth + 1, CnosDBDataType.UINT); + case FUNCTION: + return generateFunction(depth + 1, CnosDBDataType.UINT); + case BINARY_ARITHMETIC_EXPRESSION: + return new CnosDBBinaryArithmeticOperation(generateExpression(depth + 1, CnosDBDataType.UINT), + generateExpression(depth + 1, CnosDBDataType.UINT), + CnosDBBinaryOperator.getRandom(CnosDBDataType.UINT)); + default: + throw new AssertionError(); + } + + } + + private CnosDBExpression generateFloatExpression(int depth) { + FloatExpression option; + option = Randomly.fromOptions(FloatExpression.values()); + switch (option) { + case CAST: + return generateCastExpression(depth + 1, CnosDBDataType.DOUBLE); + case UNARY_OPERATION: + CnosDBExpression floatExpression = generateExpression(depth + 1, CnosDBDataType.DOUBLE); + return new CnosDBPrefixOperation(floatExpression, + Randomly.getBoolean() ? PrefixOperator.UNARY_PLUS : PrefixOperator.UNARY_MINUS); + case FUNCTION: + return generateFunction(depth + 1, CnosDBDataType.DOUBLE); + case BINARY_ARITHMETIC_EXPRESSION: + return new CnosDBBinaryArithmeticOperation(generateExpression(depth + 1, CnosDBDataType.DOUBLE), + generateExpression(depth + 1, CnosDBDataType.DOUBLE), + CnosDBBinaryOperator.getRandom(CnosDBDataType.DOUBLE)); + case CONSTANT: + return generateConstant(r, CnosDBDataType.DOUBLE); + default: + throw new AssertionError(); + } + } + + private CnosDBExpression generateTimeStampExpression(int depth) { + if (Randomly.getBoolean()) { + return generateConstant(r, CnosDBDataType.TIMESTAMP); + } + TimestampExpression option; + option = Randomly.fromOptions(TimestampExpression.values()); + switch (option) { + case CAST: + return generateCastExpression(depth + 1, CnosDBDataType.TIMESTAMP); + case FUNCTION: + return generateFunction(depth + 1, CnosDBDataType.TIMESTAMP); + default: + throw new AssertionError(); + } + } + + private CnosDBExpression generateCastExpression(int depth, CnosDBDataType castToType) { + CnosDBDataType castFromType = Randomly.fromList(CnosDBCastOperation.canCastTo(castToType)); + return new CnosDBCastOperation(generateExpression(depth + 1, castFromType), getCompoundDataType(castToType)); + } + + private CnosDBExpression createColumnOfType(CnosDBDataType type) { + List columns = filterColumns(type); + if (columns.isEmpty()) { + throw new IgnoreMeException(); + } + CnosDBColumn fromList = Randomly.fromList(columns); + return CnosDBColumnValue.create(fromList); + } + + final List filterColumns(CnosDBDataType type) { + if (columns == null) { + return Collections.emptyList(); + } else { + return columns.stream().filter(c -> c.getType() == type).collect(Collectors.toList()); + } + } + + public List generateExpressions(int nr) { + List expressions = new ArrayList<>(); + for (int i = 0; i < nr; i++) { + expressions.add(generateExpression(0)); + } + return expressions; + } + + public CnosDBExpression generateExpression(CnosDBDataType dataType) { + return generateExpression(0, dataType); + } + + public CnosDBExpression generateHavingClause() { + this.allowAggregateFunctions = true; + CnosDBExpression expression = generateExpression(CnosDBDataType.BOOLEAN); + this.allowAggregateFunctions = false; + return expression; + } + + public CnosDBExpression generateAggregate() { + return getAggregate(CnosDBDataType.getRandomType()); + } + + private CnosDBExpression getAggregate(CnosDBDataType dataType) { + if (dataType == CnosDBDataType.BOOLEAN) { + List aggregates = CnosDBAggregateFunction.getAggregates(CnosDBDataType.INT); + CnosDBAggregateFunction agg = Randomly.fromList(aggregates); + return new CnosDBCastOperation(generateArgsForAggregate(dataType, agg), + CnosDBCompoundDataType.create(CnosDBDataType.BOOLEAN)); + } else { + List aggregates = CnosDBAggregateFunction.getAggregates(dataType); + CnosDBAggregateFunction agg = Randomly.fromList(aggregates); + return generateArgsForAggregate(dataType, agg); + } + } + + public CnosDBAggregate generateArgsForAggregate(CnosDBDataType dataType, CnosDBAggregateFunction agg) { + CnosDBDataType[] types = agg.getArgsTypes(dataType); + List args = new ArrayList<>(); + for (CnosDBDataType argType : types) { + args.add(createColumnOfType(argType)); + // args.add(generateExpression(argType)); + } + return new CnosDBAggregate(args, agg); + } + + public CnosDBExpressionGenerator allowAggregates(boolean value) { + allowAggregateFunctions = value; + return this; + } + + @Override + public CnosDBExpression generatePredicate() { + return generateExpression(CnosDBDataType.BOOLEAN); + } + + @Override + public CnosDBExpression negatePredicate(CnosDBExpression predicate) { + return new CnosDBPrefixOperation(predicate, PrefixOperator.NOT); + } + + @Override + public CnosDBExpression isNull(CnosDBExpression expr) { + return new CnosDBPostfixOperation(expr, PostfixOperator.IS_NULL); + } + + private enum BooleanExpression { + POSTFIX_OPERATOR, NOT, BINARY_LOGICAL_OPERATOR, BINARY_COMPARISON, FUNCTION, CAST, LIKE, BETWEEN, IN_OPERATION, + SIMILAR_TO, + } + + private enum StringExpression { + CAST, FUNCTION, CONCAT + } + + private enum IntExpression { + UNARY_OPERATION, FUNCTION, CAST, BINARY_ARITHMETIC_EXPRESSION + } + + private enum UIntExpression { + FUNCTION, CAST, BINARY_ARITHMETIC_EXPRESSION + } + + private enum FloatExpression { + UNARY_OPERATION, FUNCTION, CAST, BINARY_ARITHMETIC_EXPRESSION, CONSTANT + } + + private enum TimestampExpression { + FUNCTION, CAST + } + +} diff --git a/src/sqlancer/cnosdb/gen/CnosDBInsertGenerator.java b/src/sqlancer/cnosdb/gen/CnosDBInsertGenerator.java new file mode 100644 index 000000000..0d575d3c7 --- /dev/null +++ b/src/sqlancer/cnosdb/gen/CnosDBInsertGenerator.java @@ -0,0 +1,59 @@ +package sqlancer.cnosdb.gen; + +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.cnosdb.CnosDBGlobalState; +import sqlancer.cnosdb.CnosDBSchema.CnosDBColumn; +import sqlancer.cnosdb.CnosDBSchema.CnosDBTable; +import sqlancer.cnosdb.CnosDBVisitor; +import sqlancer.cnosdb.ast.CnosDBExpression; +import sqlancer.cnosdb.query.CnosDBOtherQuery; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.schema.AbstractTableColumn; + +public final class CnosDBInsertGenerator { + + private CnosDBInsertGenerator() { + } + + public static CnosDBOtherQuery insert(CnosDBGlobalState globalState) { + CnosDBTable table = globalState.getSchema().getRandomTable(); + ExpectedErrors errors = new ExpectedErrors(); + errors.add("Column time cannot be null."); + StringBuilder sb = new StringBuilder(); + sb.append("INSERT "); + sb.append(table.getName()); + List columns = table.getRandomNonEmptyColumnSubset(); + sb.append("("); + sb.append(columns.stream().map(AbstractTableColumn::getName).collect(Collectors.joining(", "))); + sb.append(")"); + sb.append(" VALUES"); + + int n = Randomly.smallNumber() + 1; + for (int i = 0; i < n; i++) { + if (i != 0) { + sb.append(", "); + } + insertRow(globalState, sb, columns); + } + + // error + return new CnosDBOtherQuery(sb.toString(), errors); + } + + private static void insertRow(CnosDBGlobalState globalState, StringBuilder sb, List columns) { + sb.append("("); + for (int i = 0; i < columns.size(); i++) { + if (i > 0) { + sb.append(", "); + } + CnosDBExpression generateConstant = CnosDBExpressionGenerator.generateConstant(globalState.getRandomly(), + columns.get(i).getType()); + sb.append(CnosDBVisitor.asString(generateConstant)); + } + sb.append(")"); + } + +} diff --git a/src/sqlancer/cnosdb/gen/CnosDBTableGenerator.java b/src/sqlancer/cnosdb/gen/CnosDBTableGenerator.java new file mode 100644 index 000000000..c046ad3e9 --- /dev/null +++ b/src/sqlancer/cnosdb/gen/CnosDBTableGenerator.java @@ -0,0 +1,77 @@ +package sqlancer.cnosdb.gen; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.cnosdb.CnosDBSchema.CnosDBColumn; +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; +import sqlancer.cnosdb.CnosDBSchema.CnosDBFieldColumn; +import sqlancer.cnosdb.CnosDBSchema.CnosDBTable; +import sqlancer.cnosdb.CnosDBSchema.CnosDBTagColumn; +import sqlancer.cnosdb.query.CnosDBOtherQuery; +import sqlancer.common.query.ExpectedErrors; + +public class CnosDBTableGenerator { + + protected final ExpectedErrors errors = new ExpectedErrors(); + private final String tableName; + private final StringBuilder sb = new StringBuilder(); + private final List columnsToBeAdd = new ArrayList<>(); + private CnosDBTable table; + + public CnosDBTableGenerator(String tableName) { + this.tableName = tableName; + } + + public static CnosDBOtherQuery generate(String tableName) { + return new CnosDBTableGenerator(tableName).generate(); + } + + protected CnosDBOtherQuery generate() { + table = new CnosDBTable(tableName, columnsToBeAdd); + + sb.append("CREATE TABLE"); + if (Randomly.getBoolean()) { + sb.append(" IF NOT EXISTS"); + } + sb.append(" "); + sb.append(tableName); + + sb.append("("); + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + String name = String.format("f%d", i); + createField(name); + sb.append(", "); + } + + sb.append("TAGS("); + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + if (i != 0) { + sb.append(", "); + } + String name = String.format("t%d", i); + createTag(name); + } + sb.append("))"); + return new CnosDBOtherQuery(sb.toString(), new ExpectedErrors()); + } + + private void createField(String name) throws AssertionError { + sb.append(name); + sb.append(" "); + CnosDBDataType type = CnosDBDataType.getRandomTypeWithoutTimeStamp(); + CnosDBCommon.appendDataType(type, sb); + CnosDBFieldColumn c = new CnosDBFieldColumn(name, type); + c.setTable(table); + sb.append(" "); + columnsToBeAdd.add(c); + } + + private void createTag(String name) { + sb.append(name); + CnosDBColumn column = new CnosDBTagColumn(name); + column.setTable(table); + columnsToBeAdd.add(column); + } +} diff --git a/src/sqlancer/cnosdb/oracle/CnosDBNoRECBase.java b/src/sqlancer/cnosdb/oracle/CnosDBNoRECBase.java new file mode 100644 index 000000000..472aa8f66 --- /dev/null +++ b/src/sqlancer/cnosdb/oracle/CnosDBNoRECBase.java @@ -0,0 +1,23 @@ +package sqlancer.cnosdb.oracle; + +import sqlancer.Main; +import sqlancer.MainOptions; +import sqlancer.cnosdb.CnosDBGlobalState; +import sqlancer.cnosdb.client.CnosDBConnection; +import sqlancer.common.oracle.TestOracle; + +public abstract class CnosDBNoRECBase implements TestOracle { + protected final CnosDBGlobalState state; + protected final Main.StateLogger logger; + protected final MainOptions options; + protected final CnosDBConnection con; + protected String optimizedQueryString; + protected String unoptimizedQueryString; + + public CnosDBNoRECBase(CnosDBGlobalState state) { + this.state = state; + this.con = state.getConnection(); + this.logger = state.getLogger(); + this.options = state.getOptions(); + } +} diff --git a/src/sqlancer/cnosdb/oracle/CnosDBNoRECOracle.java b/src/sqlancer/cnosdb/oracle/CnosDBNoRECOracle.java new file mode 100644 index 000000000..0c817c655 --- /dev/null +++ b/src/sqlancer/cnosdb/oracle/CnosDBNoRECOracle.java @@ -0,0 +1,171 @@ +package sqlancer.cnosdb.oracle; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.cnosdb.CnosDBCompoundDataType; +import sqlancer.cnosdb.CnosDBExpectedError; +import sqlancer.cnosdb.CnosDBGlobalState; +import sqlancer.cnosdb.CnosDBSchema; +import sqlancer.cnosdb.CnosDBSchema.CnosDBColumn; +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; +import sqlancer.cnosdb.CnosDBSchema.CnosDBTable; +import sqlancer.cnosdb.CnosDBSchema.CnosDBTables; +import sqlancer.cnosdb.CnosDBVisitor; +import sqlancer.cnosdb.ast.CnosDBCastOperation; +import sqlancer.cnosdb.ast.CnosDBColumnValue; +import sqlancer.cnosdb.ast.CnosDBExpression; +import sqlancer.cnosdb.ast.CnosDBJoin; +import sqlancer.cnosdb.ast.CnosDBJoin.CnosDBJoinType; +import sqlancer.cnosdb.ast.CnosDBPostfixText; +import sqlancer.cnosdb.ast.CnosDBSelect; +import sqlancer.cnosdb.ast.CnosDBSelect.CnosDBFromTable; +import sqlancer.cnosdb.ast.CnosDBSelect.CnosDBSubquery; +import sqlancer.cnosdb.ast.CnosDBSelect.SelectType; +import sqlancer.cnosdb.client.CnosDBResultSet; +import sqlancer.cnosdb.gen.CnosDBExpressionGenerator; +import sqlancer.cnosdb.oracle.tlp.CnosDBTLPBase; +import sqlancer.cnosdb.query.CnosDBSelectQuery; +import sqlancer.common.oracle.TestOracle; + +public class CnosDBNoRECOracle extends CnosDBNoRECBase implements TestOracle { + + private final CnosDBSchema s; + + public CnosDBNoRECOracle(CnosDBGlobalState globalState) { + super(globalState); + this.s = globalState.getSchema(); + } + + public static List getJoinStatements(CnosDBGlobalState globalState, List columns, + List tables) { + List joinStatements = new ArrayList<>(); + CnosDBExpressionGenerator gen = new CnosDBExpressionGenerator(globalState).setColumns(columns); + for (int i = 1; i < tables.size(); i++) { + CnosDBExpression joinClause = gen.generateExpression(CnosDBDataType.BOOLEAN); + CnosDBTable table = Randomly.fromList(tables); + tables.remove(table); + CnosDBJoinType options = CnosDBJoinType.getRandom(); + CnosDBJoin j = new CnosDBJoin(new CnosDBFromTable(table), joinClause, options); + joinStatements.add(j); + } + // JOIN subqueries + for (int i = 0; i < Randomly.smallNumber(); i++) { + CnosDBTables subqueryTables = globalState.getSchema().getRandomTableNonEmptyTables(); + CnosDBSubquery subquery = CnosDBTLPBase.createSubquery(globalState, String.format("sub%d", i), + subqueryTables); + CnosDBExpression joinClause = gen.generateExpression(CnosDBDataType.BOOLEAN); + CnosDBJoinType options = CnosDBJoinType.getRandom(); + CnosDBJoin j = new CnosDBJoin(subquery, joinClause, options); + joinStatements.add(j); + } + return joinStatements; + } + + @Override + public void check() throws Exception { + CnosDBTables randomTables = s.getRandomTableNonEmptyTables(); + List columns = randomTables.getColumns(); + CnosDBExpression randomWhereCondition = getRandomWhereCondition(columns); + List tables = randomTables.getTables(); + + List joinStatements = getJoinStatements(state, columns, tables); + List fromTables = tables.stream().map(CnosDBFromTable::new).collect(Collectors.toList()); + int secondCount = getUnoptimizedQueryCount(fromTables, randomWhereCondition, joinStatements); + int firstCount = getOptimizedQueryCount(fromTables, List.of(CnosDBColumn.createDummy("f0")), + randomWhereCondition, joinStatements); + if (firstCount == -1 || secondCount == -1) { + throw new IgnoreMeException(); + } + if (firstCount != secondCount) { + String queryFormatString = "-- %s;\n-- count: %d"; + String firstQueryStringWithCount = String.format(queryFormatString, optimizedQueryString, firstCount); + String secondQueryStringWithCount = String.format(queryFormatString, unoptimizedQueryString, secondCount); + state.getState().getLocalState() + .log(String.format("%s\n%s", firstQueryStringWithCount, secondQueryStringWithCount)); + String assertionMessage = String.format("the counts mismatch (%d and %d)!\n%s\n%s", firstCount, secondCount, + firstQueryStringWithCount, secondQueryStringWithCount); + throw new AssertionError(assertionMessage); + } + } + + private CnosDBExpression getRandomWhereCondition(List columns) { + return new CnosDBExpressionGenerator(state).setColumns(columns).generateExpression(CnosDBDataType.BOOLEAN); + } + + private int getUnoptimizedQueryCount(List fromTables, CnosDBExpression randomWhereCondition, + List joinStatements) throws Exception { + CnosDBSelect select = new CnosDBSelect(); + CnosDBCastOperation isTrue = new CnosDBCastOperation(randomWhereCondition, + CnosDBCompoundDataType.create(CnosDBDataType.INT)); + CnosDBPostfixText asText = new CnosDBPostfixText(isTrue, " as count", CnosDBDataType.INT); + select.setFetchColumns(List.of(asText)); + select.setFromList(fromTables); + select.setSelectType(SelectType.ALL); + select.setJoinClauses(joinStatements); + int secondCount = 0; + unoptimizedQueryString = "SELECT SUM(count) FROM (" + CnosDBVisitor.asString(select) + ") as res"; + if (options.logEachSelect()) { + logger.writeCurrent(unoptimizedQueryString); + } + CnosDBSelectQuery q = new CnosDBSelectQuery(unoptimizedQueryString, CnosDBExpectedError.expectedErrors()); + CnosDBResultSet rs; + try { + q.executeAndGet(state); + rs = q.getResultSet(); + } catch (Exception e) { + if (q.getExpectedErrors().errorIsExpected(e.getMessage())) { + throw new IgnoreMeException(); + } + throw new AssertionError(unoptimizedQueryString, e); + } + if (rs == null) { + return -1; + } + + if (rs.next()) { + secondCount += rs.getLong(1); + } + rs.close(); + return secondCount; + } + + private int getOptimizedQueryCount(List randomTables, List columns, + CnosDBExpression randomWhereCondition, List joinStatements) { + CnosDBSelect select = new CnosDBSelect(); + CnosDBColumnValue allColumns = new CnosDBColumnValue(Randomly.fromList(columns)); + select.setFetchColumns(List.of(allColumns)); + select.setFromList(randomTables); + select.setWhereClause(randomWhereCondition); + if (Randomly.getBooleanWithSmallProbability()) { + select.setOrderByClauses(new CnosDBExpressionGenerator(state).setColumns(columns).generateOrderBy()); + } + select.setSelectType(SelectType.ALL); + select.setJoinClauses(joinStatements); + int firstCount = 0; + optimizedQueryString = CnosDBVisitor.asString(select); + if (options.logEachSelect()) { + logger.writeCurrent(optimizedQueryString); + } + CnosDBSelectQuery query = new CnosDBSelectQuery(optimizedQueryString, CnosDBExpectedError.expectedErrors()); + CnosDBResultSet rs; + try { + query.executeAndGet(state); + rs = query.getResultSet(); + while (rs.next()) { + firstCount++; + } + } catch (Exception e) { + if (query.getExpectedErrors().errorIsExpected(e.getMessage())) { + throw new IgnoreMeException(); + } + + throw new IgnoreMeException(); + } + return firstCount; + } + +} diff --git a/src/sqlancer/cnosdb/oracle/tlp/CnosDBTLPAggregateOracle.java b/src/sqlancer/cnosdb/oracle/tlp/CnosDBTLPAggregateOracle.java new file mode 100644 index 000000000..b51624a94 --- /dev/null +++ b/src/sqlancer/cnosdb/oracle/tlp/CnosDBTLPAggregateOracle.java @@ -0,0 +1,176 @@ +package sqlancer.cnosdb.oracle.tlp; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import sqlancer.ComparatorHelper; +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.cnosdb.CnosDBExpectedError; +import sqlancer.cnosdb.CnosDBGlobalState; +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; +import sqlancer.cnosdb.CnosDBVisitor; +import sqlancer.cnosdb.ast.CnosDBAggregate; +import sqlancer.cnosdb.ast.CnosDBAggregate.CnosDBAggregateFunction; +import sqlancer.cnosdb.ast.CnosDBAlias; +import sqlancer.cnosdb.ast.CnosDBExpression; +import sqlancer.cnosdb.ast.CnosDBJoin; +import sqlancer.cnosdb.ast.CnosDBPostfixOperation; +import sqlancer.cnosdb.ast.CnosDBPostfixOperation.PostfixOperator; +import sqlancer.cnosdb.ast.CnosDBPrefixOperation; +import sqlancer.cnosdb.ast.CnosDBPrefixOperation.PrefixOperator; +import sqlancer.cnosdb.ast.CnosDBSelect; +import sqlancer.cnosdb.client.CnosDBResultSet; +import sqlancer.cnosdb.query.CnosDBSelectQuery; +import sqlancer.common.oracle.TestOracle; + +public class CnosDBTLPAggregateOracle extends CnosDBTLPBase implements TestOracle { + + private String firstResult; + private String secondResult; + private String originalQuery; + private String metamorphicQuery; + + public CnosDBTLPAggregateOracle(CnosDBGlobalState state) { + super(state); + } + + @Override + public void check() throws Exception { + super.check(); + aggregateCheck(); + } + + protected void aggregateCheck() { + CnosDBAggregateFunction aggregateFunction = Randomly.fromOptions(CnosDBAggregateFunction.MAX, + CnosDBAggregateFunction.MIN, CnosDBAggregateFunction.SUM); + + CnosDBAggregate aggregate = gen.generateArgsForAggregate(aggregateFunction.getRandomReturnType(), + aggregateFunction); + List fetchColumns = new ArrayList<>(); + fetchColumns.add(aggregate); + while (Randomly.getBooleanWithRatherLowProbability()) { + fetchColumns.add(gen.generateAggregate()); + } + select.setFetchColumns(fetchColumns); + if (Randomly.getBooleanWithRatherLowProbability()) { + select.setOrderByClauses(gen.generateOrderBy()); + } + originalQuery = CnosDBVisitor.asString(select); + firstResult = getAggregateResult(originalQuery); + metamorphicQuery = createMetamorphicUnionQuery(select, aggregate, select.getFromList()); + secondResult = getAggregateResult(metamorphicQuery); + + String queryFormatString = "-- %s;\n-- result: %s"; + String firstQueryString = String.format(queryFormatString, originalQuery, firstResult); + String secondQueryString = String.format(queryFormatString, metamorphicQuery, secondResult); + state.getState().getLocalState().log(String.format("%s\n%s", firstQueryString, secondQueryString)); + if (firstResult == null && secondResult != null || firstResult != null && secondResult == null + || firstResult != null && !firstResult.contentEquals(secondResult) + && !ComparatorHelper.isEqualDouble(firstResult, secondResult)) { + if (secondResult != null && secondResult.contains("Inf")) { + throw new IgnoreMeException(); // FIXME: average computation + } + String assertionMessage = String.format("%s: the results mismatch!\n%s\n%s", this.s.getDatabaseName(), + firstQueryString, secondQueryString); + throw new AssertionError(assertionMessage); + } + } + + private String createMetamorphicUnionQuery(CnosDBSelect select, CnosDBAggregate aggregate, + List from) { + String metamorphicQuery; + CnosDBExpression whereClause = gen.generateExpression(CnosDBDataType.BOOLEAN); + CnosDBExpression negatedClause = new CnosDBPrefixOperation(whereClause, PrefixOperator.NOT); + CnosDBExpression notNullClause = new CnosDBPostfixOperation(whereClause, PostfixOperator.IS_NULL); + List mappedAggregate = mapped(aggregate); + CnosDBSelect leftSelect = getSelect(mappedAggregate, from, whereClause, select.getJoinClauses()); + CnosDBSelect middleSelect = getSelect(mappedAggregate, from, negatedClause, select.getJoinClauses()); + CnosDBSelect rightSelect = getSelect(mappedAggregate, from, notNullClause, select.getJoinClauses()); + metamorphicQuery = "SELECT " + getOuterAggregateFunction(aggregate) + " FROM ("; + metamorphicQuery += CnosDBVisitor.asString(leftSelect) + " UNION ALL " + CnosDBVisitor.asString(middleSelect) + + " UNION ALL " + CnosDBVisitor.asString(rightSelect); + metamorphicQuery += ") as asdf"; + return metamorphicQuery; + } + + private String getAggregateResult(String queryString) { + // log TLP Aggregate SELECT queries on the current log file + if (state.getOptions().logEachSelect()) { + // TODO: refactor me + state.getLogger().writeCurrent(queryString); + try { + state.getLogger().getCurrentFileWriter().flush(); + } catch (IOException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + } + String resultString = null; + + CnosDBSelectQuery q = new CnosDBSelectQuery(queryString, CnosDBExpectedError.expectedErrors()); + try { + q.executeAndGet(state); + CnosDBResultSet result = q.getResultSet(); + + if (result == null || !result.next()) { + throw new IgnoreMeException(); + } + + resultString = result.getString(1); + + } catch (Exception e) { + if (q.getExpectedErrors().errorIsExpected(e.getMessage())) { + throw new IgnoreMeException(); + } + } + + return resultString; + } + + private List mapped(CnosDBAggregate aggregate) { + switch (aggregate.getFunction()) { + case SUM: + case MAX: + case MIN: + return aliasArgs(List.of(aggregate)); + // now not support + // case COUNT: + // case AVG: + default: + throw new AssertionError(aggregate.getFunction()); + } + } + + private List aliasArgs(List originalAggregateArgs) { + List args = new ArrayList<>(); + int i = 0; + for (CnosDBExpression expr : originalAggregateArgs) { + args.add(new CnosDBAlias(expr, "agg" + i++)); + } + return args; + } + + private String getOuterAggregateFunction(CnosDBAggregate aggregate) { + if (Objects.requireNonNull(aggregate.getFunction()) == CnosDBAggregateFunction.COUNT) { + return CnosDBAggregateFunction.SUM + "(agg0)"; + } + return aggregate.getFunction() + "(agg0)"; + } + + private CnosDBSelect getSelect(List aggregates, List from, + CnosDBExpression whereClause, List joinList) { + CnosDBSelect leftSelect = new CnosDBSelect(); + leftSelect.setFetchColumns(aggregates); + leftSelect.setFromList(from); + leftSelect.setWhereClause(whereClause); + leftSelect.setJoinClauses(joinList); + if (Randomly.getBooleanWithSmallProbability()) { + leftSelect.setGroupByExpressions(gen.generateExpressions(Randomly.smallNumber() + 1)); + } + return leftSelect; + } + +} diff --git a/src/sqlancer/cnosdb/oracle/tlp/CnosDBTLPBase.java b/src/sqlancer/cnosdb/oracle/tlp/CnosDBTLPBase.java new file mode 100644 index 000000000..bd7ba3b55 --- /dev/null +++ b/src/sqlancer/cnosdb/oracle/tlp/CnosDBTLPBase.java @@ -0,0 +1,112 @@ +package sqlancer.cnosdb.oracle.tlp; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.cnosdb.CnosDBGlobalState; +import sqlancer.cnosdb.CnosDBSchema; +import sqlancer.cnosdb.CnosDBSchema.CnosDBColumn; +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; +import sqlancer.cnosdb.CnosDBSchema.CnosDBTable; +import sqlancer.cnosdb.CnosDBSchema.CnosDBTables; +import sqlancer.cnosdb.ast.CnosDBColumnValue; +import sqlancer.cnosdb.ast.CnosDBConstant; +import sqlancer.cnosdb.ast.CnosDBExpression; +import sqlancer.cnosdb.ast.CnosDBJoin; +import sqlancer.cnosdb.ast.CnosDBSelect; +import sqlancer.cnosdb.ast.CnosDBSelect.CnosDBFromTable; +import sqlancer.cnosdb.ast.CnosDBSelect.CnosDBSubquery; +import sqlancer.cnosdb.gen.CnosDBExpressionGenerator; +import sqlancer.cnosdb.oracle.CnosDBNoRECOracle; +import sqlancer.common.gen.ExpressionGenerator; +import sqlancer.common.oracle.TernaryLogicPartitioningOracleBase; +import sqlancer.common.oracle.TestOracle; + +public class CnosDBTLPBase extends TernaryLogicPartitioningOracleBase + implements TestOracle { + + protected CnosDBSchema s; + protected CnosDBTables targetTables; + protected CnosDBExpressionGenerator gen; + protected CnosDBSelect select; + + public CnosDBTLPBase(CnosDBGlobalState state) { + super(state); + } + + public static CnosDBSubquery createSubquery(CnosDBGlobalState globalState, String name, CnosDBTables tables) { + List columns = new ArrayList<>(); + CnosDBExpressionGenerator gen = new CnosDBExpressionGenerator(globalState).setColumns(tables.getColumns()); + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + columns.add(gen.generateExpression(0)); + } + CnosDBSelect select = new CnosDBSelect(); + select.setFromList(tables.getTables().stream().map(CnosDBFromTable::new).collect(Collectors.toList())); + select.setFetchColumns(columns); + if (Randomly.getBoolean()) { + select.setWhereClause(gen.generateExpression(0, CnosDBDataType.BOOLEAN)); + } + if (Randomly.getBooleanWithRatherLowProbability()) { + select.setOrderByClauses(gen.generateOrderBy()); + } + if (Randomly.getBoolean()) { + select.setLimitClause(CnosDBConstant.createIntConstant(Randomly.getPositiveOrZeroNonCachedInteger())); + if (Randomly.getBoolean()) { + select.setOffsetClause(CnosDBConstant.createIntConstant(Randomly.getPositiveOrZeroNonCachedInteger())); + } + } + return new CnosDBSubquery(select, name); + } + + @Override + public void check() throws Exception { + s = state.getSchema(); + targetTables = s.getRandomTableNonEmptyTables(); + List tables = targetTables.getTables(); + List joins = getJoinStatements(targetTables.getColumns(), tables); + generateSelectBase(tables, joins); + } + + protected List getJoinStatements(List columns, List tables) { + return CnosDBNoRECOracle.getJoinStatements(state, columns, tables); + } + + protected void generateSelectBase(List tables, List joins) { + List tableList = tables.stream().map(CnosDBFromTable::new).collect(Collectors.toList()); + gen = new CnosDBExpressionGenerator(state).setColumns(targetTables.getColumns()); + initializeTernaryPredicateVariants(); + select = new CnosDBSelect(); + select.setFetchColumns(generateFetchColumns()); + select.setFromList(tableList); + select.setWhereClause(null); + select.setJoinClauses(joins); + } + + List generateFetchColumns() { + if (Randomly.getBooleanWithRatherLowProbability()) { + return List.of(new CnosDBColumnValue(CnosDBColumn.createDummy("*"))); + } + List fetchColumns = new ArrayList<>(); + List targetColumns = targetTables.getRandomColumnsWithOnlyOneField(); + + ArrayList columns = new ArrayList<>(); + targetColumns.forEach(column -> column.getTable().getColumns().stream() + .filter(field -> field instanceof CnosDBSchema.CnosDBFieldColumn).findFirst().ifPresent(columns::add)); + targetColumns.addAll(columns); + + targetColumns = targetColumns.stream().distinct().collect(Collectors.toList()); + + for (CnosDBColumn c : targetColumns) { + fetchColumns.add(new CnosDBColumnValue(c)); + } + return fetchColumns; + } + + @Override + protected ExpressionGenerator getGen() { + return gen; + } + +} diff --git a/src/sqlancer/cnosdb/oracle/tlp/CnosDBTLPHavingOracle.java b/src/sqlancer/cnosdb/oracle/tlp/CnosDBTLPHavingOracle.java new file mode 100644 index 000000000..283d59a23 --- /dev/null +++ b/src/sqlancer/cnosdb/oracle/tlp/CnosDBTLPHavingOracle.java @@ -0,0 +1,65 @@ +package sqlancer.cnosdb.oracle.tlp; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.cnosdb.CnosDBComparatorHelper; +import sqlancer.cnosdb.CnosDBExpectedError; +import sqlancer.cnosdb.CnosDBGlobalState; +import sqlancer.cnosdb.CnosDBSchema.CnosDBDataType; +import sqlancer.cnosdb.CnosDBVisitor; +import sqlancer.cnosdb.ast.CnosDBExpression; + +public class CnosDBTLPHavingOracle extends CnosDBTLPBase { + + public CnosDBTLPHavingOracle(CnosDBGlobalState state) { + super(state); + } + + @Override + public void check() throws Exception { + super.check(); + havingCheck(); + } + + protected void havingCheck() throws Exception { + if (Randomly.getBoolean()) { + select.setWhereClause(gen.generateExpression(CnosDBDataType.BOOLEAN)); + } + select.setGroupByExpressions(gen.generateExpressions(Randomly.smallNumber() + 1)); + select.setHavingClause(null); + String originalQueryString = CnosDBVisitor.asString(select); + List resultSet = CnosDBComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, + CnosDBExpectedError.expectedErrors(), state); + + boolean orderBy = Randomly.getBoolean(); + if (orderBy) { + select.setOrderByClauses(gen.generateOrderBy()); + } + select.setHavingClause(predicate); + String firstQueryString = CnosDBVisitor.asString(select); + select.setHavingClause(negatedPredicate); + String secondQueryString = CnosDBVisitor.asString(select); + select.setHavingClause(isNullPredicate); + String thirdQueryString = CnosDBVisitor.asString(select); + List combinedString = new ArrayList<>(); + List secondResultSet = CnosDBComparatorHelper.getCombinedResultSet(firstQueryString, secondQueryString, + thirdQueryString, combinedString, !orderBy, state, CnosDBExpectedError.expectedErrors()); + CnosDBComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, + state); + } + + @Override + protected CnosDBExpression generatePredicate() { + return gen.generateHavingClause(); + } + + @Override + List generateFetchColumns() { + List expressions = gen.allowAggregates(true).generateExpressions(Randomly.smallNumber() + 1); + gen.allowAggregates(false); + return expressions; + } + +} diff --git a/src/sqlancer/cnosdb/oracle/tlp/CnosDBTLPWhereOracle.java b/src/sqlancer/cnosdb/oracle/tlp/CnosDBTLPWhereOracle.java new file mode 100644 index 000000000..8e118435d --- /dev/null +++ b/src/sqlancer/cnosdb/oracle/tlp/CnosDBTLPWhereOracle.java @@ -0,0 +1,46 @@ +package sqlancer.cnosdb.oracle.tlp; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.cnosdb.CnosDBComparatorHelper; +import sqlancer.cnosdb.CnosDBExpectedError; +import sqlancer.cnosdb.CnosDBGlobalState; +import sqlancer.cnosdb.CnosDBVisitor; + +public class CnosDBTLPWhereOracle extends CnosDBTLPBase { + + public CnosDBTLPWhereOracle(CnosDBGlobalState state) { + super(state); + } + + @Override + public void check() throws Exception { + super.check(); + whereCheck(); + } + + protected void whereCheck() throws Exception { + if (Randomly.getBooleanWithRatherLowProbability()) { + select.setOrderByClauses(gen.generateOrderBy()); + } + String originalQueryString = CnosDBVisitor.asString(select); + List resultSet = CnosDBComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, + CnosDBExpectedError.expectedErrors(), state); + + select.setOrderByClauses(Collections.emptyList()); + select.setWhereClause(predicate); + String firstQueryString = CnosDBVisitor.asString(select); + select.setWhereClause(negatedPredicate); + String secondQueryString = CnosDBVisitor.asString(select); + select.setWhereClause(isNullPredicate); + String thirdQueryString = CnosDBVisitor.asString(select); + List combinedString = new ArrayList<>(); + List secondResultSet = CnosDBComparatorHelper.getCombinedResultSet(firstQueryString, secondQueryString, + thirdQueryString, combinedString, Randomly.getBoolean(), state, CnosDBExpectedError.expectedErrors()); + CnosDBComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, + state); + } +} diff --git a/src/sqlancer/cnosdb/query/CnosDBOtherQuery.java b/src/sqlancer/cnosdb/query/CnosDBOtherQuery.java new file mode 100644 index 000000000..f0a37056c --- /dev/null +++ b/src/sqlancer/cnosdb/query/CnosDBOtherQuery.java @@ -0,0 +1,32 @@ +package sqlancer.cnosdb.query; + +import sqlancer.GlobalState; +import sqlancer.IgnoreMeException; +import sqlancer.cnosdb.client.CnosDBConnection; +import sqlancer.common.query.ExpectedErrors; + +public class CnosDBOtherQuery extends CnosDBQueryAdapter { + private static final long serialVersionUID = 1L; + + public CnosDBOtherQuery(String query, ExpectedErrors errors) { + super(query, errors); + } + + @Override + public boolean couldAffectSchema() { + return true; + } + + @Override + public > boolean execute(G globalState, String... fills) + throws Exception { + try { + globalState.getConnection().getClient().execute(query); + } catch (Exception e) { + if (this.errors.errorIsExpected(e.getMessage())) { + throw new IgnoreMeException(); + } + } + return true; + } +} diff --git a/src/sqlancer/cnosdb/query/CnosDBQueryAdapter.java b/src/sqlancer/cnosdb/query/CnosDBQueryAdapter.java new file mode 100644 index 000000000..115f96ffc --- /dev/null +++ b/src/sqlancer/cnosdb/query/CnosDBQueryAdapter.java @@ -0,0 +1,42 @@ +package sqlancer.cnosdb.query; + +import sqlancer.cnosdb.client.CnosDBConnection; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.Query; + +public abstract class CnosDBQueryAdapter extends Query { + private static final long serialVersionUID = 1L; + + String query; + ExpectedErrors errors; + + public CnosDBQueryAdapter(String query, ExpectedErrors errors) { + this.query = query; + this.errors = errors; + } + + @Override + public String getLogString() { + return query; + } + + @Override + public String getQueryString() { + return query; + } + + @Override + public String getUnterminatedQueryString() { + return null; + } + + @Override + public boolean couldAffectSchema() { + return false; + } + + @Override + public ExpectedErrors getExpectedErrors() { + return errors; + } +} diff --git a/src/sqlancer/cnosdb/query/CnosDBQueryProvider.java b/src/sqlancer/cnosdb/query/CnosDBQueryProvider.java new file mode 100644 index 000000000..dee38abf4 --- /dev/null +++ b/src/sqlancer/cnosdb/query/CnosDBQueryProvider.java @@ -0,0 +1,6 @@ +package sqlancer.cnosdb.query; + +@FunctionalInterface +public interface CnosDBQueryProvider { + CnosDBOtherQuery getQuery(S globalState) throws Exception; +} diff --git a/src/sqlancer/cnosdb/query/CnosDBSelectQuery.java b/src/sqlancer/cnosdb/query/CnosDBSelectQuery.java new file mode 100644 index 000000000..1c9228182 --- /dev/null +++ b/src/sqlancer/cnosdb/query/CnosDBSelectQuery.java @@ -0,0 +1,39 @@ +package sqlancer.cnosdb.query; + +import sqlancer.GlobalState; +import sqlancer.cnosdb.client.CnosDBConnection; +import sqlancer.cnosdb.client.CnosDBResultSet; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLancerResultSet; + +public class CnosDBSelectQuery extends CnosDBQueryAdapter { + private static final long serialVersionUID = 1L; + CnosDBResultSet resultSet; + + public CnosDBSelectQuery(String query, ExpectedErrors errors) { + super(query, errors); + } + + @Override + public boolean couldAffectSchema() { + return false; + } + + @Override + public > boolean execute(G globalState, String... fills) + throws Exception { + globalState.getConnection().getClient().execute(query); + return false; + } + + @Override + public > SQLancerResultSet executeAndGet(G globalState, + String... fills) throws Exception { + resultSet = globalState.getConnection().getClient().executeQuery(query); + return null; + } + + public CnosDBResultSet getResultSet() { + return resultSet; + } +} diff --git a/src/sqlancer/cockroachdb/CockroachDBBugs.java b/src/sqlancer/cockroachdb/CockroachDBBugs.java index 3a980a8c7..86d88ee21 100644 --- a/src/sqlancer/cockroachdb/CockroachDBBugs.java +++ b/src/sqlancer/cockroachdb/CockroachDBBugs.java @@ -18,37 +18,55 @@ public final class CockroachDBBugs { public static boolean bug83874 = true; // https://github.com/cockroachdb/cockroach/issues/83973 - public static boolean bug83973 = true; + public static boolean bug83973; // https://github.com/cockroachdb/cockroach/issues/83976 - public static boolean bug83976 = true; + public static boolean bug83976; + // The following bug is closed, but leave it enabled until + // the underlying interval issue is resolved. + // https://github.com/cockroachdb/cockroach/issues/84078 // https://github.com/cockroachdb/cockroach/issues/84154 public static boolean bug84154 = true; // https://github.com/cockroachdb/cockroach/issues/85356 - public static boolean bug85356 = true; + public static boolean bug85356; // https://github.com/cockroachdb/cockroach/issues/85371 - public static boolean bug85371 = true; + public static boolean bug85371; // https://github.com/cockroachdb/cockroach/issues/85389 - public static boolean bug85389 = true; + public static boolean bug85389; // https://github.com/cockroachdb/cockroach/issues/85390 - public static boolean bug85390 = true; + public static boolean bug85390; // https://github.com/cockroachdb/cockroach/issues/85393 - public static boolean bug85393 = true; + public static boolean bug85393; // https://github.com/cockroachdb/cockroach/issues/85394 public static boolean bug85394 = true; // https://github.com/cockroachdb/cockroach/issues/85441 - public static boolean bug85441 = true; + public static boolean bug85441; // https://github.com/cockroachdb/cockroach/issues/85499 - public static boolean bug85499 = true; + public static boolean bug85499; + + // https://github.com/cockroachdb/cockroach/issues/88037 + public static boolean bug88037; + + // https://github.com/cockroachdb/cockroach/issues/85230 + public static boolean bug85230 = true; + + // https://github.com/cockroachdb/cockroach/issues/131640 + public static boolean bug131640 = true; + + // https://github.com/cockroachdb/cockroach/issues/131647 + public static boolean bug131647 = true; + + // https://github.com/cockroachdb/cockroach/issues/131875 + public static boolean bug131875 = true; private CockroachDBBugs() { } diff --git a/src/sqlancer/cockroachdb/CockroachDBErrors.java b/src/sqlancer/cockroachdb/CockroachDBErrors.java index 66a5e5d7d..134ad35aa 100644 --- a/src/sqlancer/cockroachdb/CockroachDBErrors.java +++ b/src/sqlancer/cockroachdb/CockroachDBErrors.java @@ -1,5 +1,8 @@ package sqlancer.cockroachdb; +import java.util.ArrayList; +import java.util.List; + import sqlancer.common.query.ExpectedErrors; public final class CockroachDBErrors { @@ -7,7 +10,9 @@ public final class CockroachDBErrors { private CockroachDBErrors() { } - public static void addExpressionErrors(ExpectedErrors errors) { + public static List getExpressionErrors() { + ArrayList errors = new ArrayList<>(); + errors.add(" non-streaming operator encountered when vectorize=auto"); if (CockroachDBBugs.bug46915) { @@ -19,7 +24,8 @@ public static void addExpressionErrors(ExpectedErrors errors) { } errors.add("exceeds supported timestamp bounds"); - + errors.add("expected STORED COMPUTED COLUMN expression to have type bytes"); + errors.add("volatile functions are not allowed in STORED COMPUTED COLUMN"); errors.add("cannot cast negative integer to bit varying with unbounded width"); errors.add("negative value for LIMIT"); @@ -74,6 +80,10 @@ public static void addExpressionErrors(ExpectedErrors errors) { errors.add(" unsupported comparison operator: NOT LIKE != "); errors.add("expected DEFAULT expression to have type bytes"); + errors.add("expected DEFAULT (in CREATE TABLE) expression to have type bytes"); + errors.add("expected DEFAULT (in CREATE VIEW) expression to have type bytes"); + errors.add("expected DEFAULT (in SET DEFAULT) expression to have type bytes"); + errors.add("expected DEFAULT (in ADD COLUMN) expression to have type bytes"); errors.add("value type string doesn't match type bytes of column"); errors.add("as decimal, found type: int"); errors.add("to be of type decimal, found type float"); @@ -90,20 +100,16 @@ public static void addExpressionErrors(ExpectedErrors errors) { errors.add("LOOKUP can only be used with INNER or LEFT joins"); // TODO errors.add("ambiguous binary operator: || "); - errors.add(" ERROR: unsupported binary operator: || (desired )"); - errors.add("unsupported binary operator: || (desired )"); - errors.add("incompatible value type: unsupported binary operator: || (desired )"); - errors.add("unsupported binary operator: || (desired )"); - errors.add("unsupported binary operator: || (desired )"); + errors.add("unsupported binary operator"); errors.add("parsing as type timestamp: empty or blank input"); errors.add("parsing as type timestamp: field"); errors.add("as type time"); errors.add("as TimeTZ"); errors.add("as type decimal"); - addIntervalTypeErrors(errors); - addFunctionErrors(errors); - addGroupByErrors(errors); - addJoinTypes(errors); + errors.addAll(getIntervalTypeErrors()); + errors.addAll(getFunctionErrors()); + errors.addAll(getGroupByErrors()); + errors.addAll(getJoinTypes()); errors.add("as int4, found type: decimal"); errors.add("to be of type int2, found type decimal"); errors.add("to be of type int, found type decimal"); // arithmetic overflows @@ -111,7 +117,6 @@ public static void addExpressionErrors(ExpectedErrors errors) { errors.add("numeric constant out of int64 range"); errors.add("unknown signature: overlay"); errors.add("unknown signature: substring"); - errors.add("unsupported binary operator: + (desired )"); errors.add("unsupported comparison operator"); errors.add("unknown signature: chr(decimal) (desired )"); errors.add("unknown signature: to_english(decimal) (desired )"); @@ -126,7 +131,6 @@ public static void addExpressionErrors(ExpectedErrors errors) { errors.add("has type decimal"); errors.add("to be of type decimal, found type int"); errors.add("value type decimal doesn't match type int"); - errors.add("unsupported binary operator: / (desired )"); errors.add("(desired )"); errors.add("(desired )"); errors.add("(desired )"); @@ -142,10 +146,7 @@ public static void addExpressionErrors(ExpectedErrors errors) { errors.add("exists but is not a directory"); // TODO - errors.add("could not parse JSON: trailing characters after JSON document"); - errors.add("could not parse JSON: unable to decode JSON: invalid character"); - errors.add("could not parse JSON: unable to decode JSON: EOF"); - errors.add("could not parse JSON: unable to decode JSON: unexpected EOF"); + errors.add("could not parse JSON"); errors.add("can't order by column type jsonb"); errors.add("odd length hex string"); @@ -186,6 +187,9 @@ public static void addExpressionErrors(ExpectedErrors errors) { if (CockroachDBBugs.bug85499) { errors.add("estimated row count must be non-zero"); } + if (CockroachDBBugs.bug88037) { + errors.add("expected required columns to be a subset of output columns"); + } errors.add("unable to vectorize execution plan"); // SET vectorize=experimental_always; errors.add(" mismatched physical types at index"); // SET vectorize=experimental_always; @@ -202,10 +206,21 @@ public static void addExpressionErrors(ExpectedErrors errors) { errors.add("argument of OFFSET must be type int, not type decimal"); errors.add("ERROR: for SELECT DISTINCT, ORDER BY expressions must appear in select list"); - addArrayErrors(errors); + errors.add("incompatible IF expressions"); + + errors.addAll(getArrayErrors()); + errors.addAll(getComputedColumnErrors()); + + return errors; } - private static void addArrayErrors(ExpectedErrors errors) { + public static void addExpressionErrors(ExpectedErrors errors) { + errors.addAll(getExpressionErrors()); + } + + private static List getArrayErrors() { + ArrayList errors = new ArrayList<>(); + // arrays errors.add("cannot determine type of empty array"); errors.add("unknown signature: max(unknown[])"); @@ -251,18 +266,29 @@ private static void addArrayErrors(ExpectedErrors errors) { errors.add("to be of type int[], found type decimal[]"); errors.add("to be of type unknown[]"); // IF with null array + + return errors; } - private static void addIntervalTypeErrors(ExpectedErrors errors) { + private static List getIntervalTypeErrors() { + ArrayList errors = new ArrayList<>(); + errors.add("overflow during Encode"); errors.add("type interval"); + + return errors; } - private static void addJoinTypes(ExpectedErrors errors) { + private static List getJoinTypes() { + ArrayList errors = new ArrayList<>(); + errors.add("JOIN/USING types"); + + return errors; } - private static void addGroupByErrors(ExpectedErrors errors) { + private static List getGroupByErrors() { + ArrayList errors = new ArrayList<>(); errors.add("non-integer constant in GROUP BY"); // https://github.com/cockroachdb/cockroach/pull/46649 -> aggregates on NULL are @@ -284,9 +310,11 @@ private static void addGroupByErrors(ExpectedErrors errors) { errors.add("unknown signature: abs(string)"); errors.add("unknown signature: acos(string)"); + return errors; } - private static void addFunctionErrors(ExpectedErrors errors) { + private static List getFunctionErrors() { + ArrayList errors = new ArrayList<>(); // functions errors.add("abs of min integer value (-9223372036854775808) not defined"); // ABS errors.add("the input string must not be empty"); // ASCII @@ -302,10 +330,30 @@ private static void addFunctionErrors(ExpectedErrors errors) { errors.add("substring(): negative substring length"); // substring errors.add("negative substring length"); // substring errors.add("must be greater than zero"); // split_part + + return errors; } - public static void addTransactionErrors(ExpectedErrors errors) { + public static List getTransactionErrors() { + ArrayList errors = new ArrayList<>(); + errors.add("current transaction is aborted"); + + return errors; + } + + public static void addTransactionErrors(ExpectedErrors errors) { + errors.addAll(getTransactionErrors()); + } + + private static List getComputedColumnErrors() { + ArrayList errors = new ArrayList<>(); + + // computed columns + errors.add("computed column expressions cannot reference computed columns"); + errors.add("STORED COMPUTED COLUMN expression cannot reference computed columns"); + + return errors; } } diff --git a/src/sqlancer/cockroachdb/CockroachDBOptions.java b/src/sqlancer/cockroachdb/CockroachDBOptions.java index aa41259b0..ce8a207d6 100644 --- a/src/sqlancer/cockroachdb/CockroachDBOptions.java +++ b/src/sqlancer/cockroachdb/CockroachDBOptions.java @@ -1,7 +1,5 @@ package sqlancer.cockroachdb; -import java.sql.SQLException; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -9,18 +7,6 @@ import com.beust.jcommander.Parameters; import sqlancer.DBMSSpecificOptions; -import sqlancer.OracleFactory; -import sqlancer.cockroachdb.CockroachDBOptions.CockroachDBOracleFactory; -import sqlancer.cockroachdb.CockroachDBProvider.CockroachDBGlobalState; -import sqlancer.cockroachdb.oracle.CockroachDBNoRECOracle; -import sqlancer.cockroachdb.oracle.tlp.CockroachDBTLPAggregateOracle; -import sqlancer.cockroachdb.oracle.tlp.CockroachDBTLPDistinctOracle; -import sqlancer.cockroachdb.oracle.tlp.CockroachDBTLPExtendedWhereOracle; -import sqlancer.cockroachdb.oracle.tlp.CockroachDBTLPGroupByOracle; -import sqlancer.cockroachdb.oracle.tlp.CockroachDBTLPHavingOracle; -import sqlancer.cockroachdb.oracle.tlp.CockroachDBTLPWhereOracle; -import sqlancer.common.oracle.CompositeTestOracle; -import sqlancer.common.oracle.TestOracle; @Parameters(separators = "=", commandDescription = "CockroachDB (default port: " + CockroachDBOptions.DEFAULT_PORT + " default host: " + CockroachDBOptions.DEFAULT_HOST + ")") @@ -31,67 +17,6 @@ public class CockroachDBOptions implements DBMSSpecificOptions { - NOREC { - @Override - public TestOracle create(CockroachDBGlobalState globalState) throws SQLException { - return new CockroachDBNoRECOracle(globalState); - } - }, - AGGREGATE { - - @Override - public TestOracle create(CockroachDBGlobalState globalState) throws SQLException { - return new CockroachDBTLPAggregateOracle(globalState); - } - - }, - GROUP_BY { - @Override - public TestOracle create(CockroachDBGlobalState globalState) throws SQLException { - return new CockroachDBTLPGroupByOracle(globalState); - } - }, - HAVING { - @Override - public TestOracle create(CockroachDBGlobalState globalState) throws SQLException { - return new CockroachDBTLPHavingOracle(globalState); - } - }, - WHERE { - @Override - public TestOracle create(CockroachDBGlobalState globalState) throws SQLException { - return new CockroachDBTLPWhereOracle(globalState); - } - }, - DISTINCT { - @Override - public TestOracle create(CockroachDBGlobalState globalState) throws SQLException { - return new CockroachDBTLPDistinctOracle(globalState); - } - }, - EXTENDED_WHERE { - @Override - public TestOracle create(CockroachDBGlobalState globalState) throws SQLException { - return new CockroachDBTLPExtendedWhereOracle(globalState); - } - }, - QUERY_PARTITIONING { - @Override - public TestOracle create(CockroachDBGlobalState globalState) throws SQLException { - List oracles = new ArrayList<>(); - oracles.add(new CockroachDBTLPAggregateOracle(globalState)); - oracles.add(new CockroachDBTLPHavingOracle(globalState)); - oracles.add(new CockroachDBTLPWhereOracle(globalState)); - oracles.add(new CockroachDBTLPGroupByOracle(globalState)); - oracles.add(new CockroachDBTLPExtendedWhereOracle(globalState)); - oracles.add(new CockroachDBTLPDistinctOracle(globalState)); - return new CompositeTestOracle(oracles, globalState); - } - }; - - } - @Parameter(names = { "--test-hash-indexes" }, description = "Test the USING HASH WITH BUCKET_COUNT=n_buckets option in CREATE INDEX") public boolean testHashIndexes = true; @@ -99,9 +24,11 @@ public TestOracle create(CockroachDBGlobalState globalState) throws SQLException @Parameter(names = { "--test-temp-tables" }, description = "Test TEMPORARY tables") public boolean testTempTables; // default: false https://github.com/cockroachdb/cockroach/issues/85388 - @Parameter(names = { - "--increased-vectorization" }, description = "Generate VECTORIZE=on with a higher probability (which found a number of bugs in the past)") - public boolean makeVectorizationMoreLikely = true; + @Parameter(names = { "--max-num-tables" }, description = "The maximum number of tables that can be created") + public int maxNumTables = 10; + + @Parameter(names = { "--max-num-indexes" }, description = "The maximum number of indexes that can be created") + public int maxNumIndexes = 20; @Override public List getTestOracleFactory() { diff --git a/src/sqlancer/cockroachdb/CockroachDBOracleFactory.java b/src/sqlancer/cockroachdb/CockroachDBOracleFactory.java new file mode 100644 index 000000000..96fbc22ce --- /dev/null +++ b/src/sqlancer/cockroachdb/CockroachDBOracleFactory.java @@ -0,0 +1,138 @@ +package sqlancer.cockroachdb; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import sqlancer.IgnoreMeException; +import sqlancer.OracleFactory; +import sqlancer.cockroachdb.gen.CockroachDBExpressionGenerator; +import sqlancer.cockroachdb.oracle.tlp.CockroachDBTLPAggregateOracle; +import sqlancer.cockroachdb.oracle.tlp.CockroachDBTLPDistinctOracle; +import sqlancer.cockroachdb.oracle.tlp.CockroachDBTLPExtendedWhereOracle; +import sqlancer.cockroachdb.oracle.tlp.CockroachDBTLPGroupByOracle; +import sqlancer.cockroachdb.oracle.tlp.CockroachDBTLPHavingOracle; +import sqlancer.common.oracle.CERTOracle; +import sqlancer.common.oracle.CompositeTestOracle; +import sqlancer.common.oracle.NoRECOracle; +import sqlancer.common.oracle.TLPWhereOracle; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLancerResultSet; + +public enum CockroachDBOracleFactory implements OracleFactory { + NOREC { + @Override + public TestOracle create( + CockroachDBProvider.CockroachDBGlobalState globalState) throws SQLException { + CockroachDBExpressionGenerator gen = new CockroachDBExpressionGenerator(globalState); + ExpectedErrors errors = ExpectedErrors.newErrors().with(CockroachDBErrors.getExpressionErrors()) + .with(CockroachDBErrors.getTransactionErrors()).with("unable to vectorize execution plan") // SET + // vectorize=experimental_always; + .with(" mismatched physical types at index") // SET vectorize=experimental_always; + .build(); + return new NoRECOracle<>(globalState, gen, errors); + } + }, + AGGREGATE { + @Override + public TestOracle create( + CockroachDBProvider.CockroachDBGlobalState globalState) throws SQLException { + return new CockroachDBTLPAggregateOracle(globalState); + } + + }, + GROUP_BY { + @Override + public TestOracle create( + CockroachDBProvider.CockroachDBGlobalState globalState) throws SQLException { + return new CockroachDBTLPGroupByOracle(globalState); + } + }, + HAVING { + @Override + public TestOracle create( + CockroachDBProvider.CockroachDBGlobalState globalState) throws SQLException { + return new CockroachDBTLPHavingOracle(globalState); + } + }, + WHERE { + @Override + public TestOracle create( + CockroachDBProvider.CockroachDBGlobalState globalState) throws SQLException { + CockroachDBExpressionGenerator gen = new CockroachDBExpressionGenerator(globalState); + ExpectedErrors expectedErrors = ExpectedErrors.newErrors().with(CockroachDBErrors.getExpressionErrors()) + .with("GROUP BY term out of range").build(); + + return new TLPWhereOracle<>(globalState, gen, expectedErrors); + } + }, + DISTINCT { + @Override + public TestOracle create( + CockroachDBProvider.CockroachDBGlobalState globalState) throws SQLException { + return new CockroachDBTLPDistinctOracle(globalState); + } + }, + EXTENDED_WHERE { + @Override + public TestOracle create( + CockroachDBProvider.CockroachDBGlobalState globalState) throws SQLException { + return new CockroachDBTLPExtendedWhereOracle(globalState); + } + }, + QUERY_PARTITIONING { + @Override + public TestOracle create( + CockroachDBProvider.CockroachDBGlobalState globalState) throws Exception { + List> oracles = new ArrayList<>(); + oracles.add(AGGREGATE.create(globalState)); + oracles.add(HAVING.create(globalState)); + oracles.add(WHERE.create(globalState)); + oracles.add(GROUP_BY.create(globalState)); + oracles.add(EXTENDED_WHERE.create(globalState)); + oracles.add(DISTINCT.create(globalState)); + return new CompositeTestOracle(oracles, globalState); + } + }, + CERT { + @Override + public TestOracle create( + CockroachDBProvider.CockroachDBGlobalState globalState) throws SQLException { + CockroachDBExpressionGenerator gen = new CockroachDBExpressionGenerator(globalState); + ExpectedErrors expectedErrors = ExpectedErrors.newErrors().with(CockroachDBErrors.getExpressionErrors()) + .build(); + CERTOracle.CheckedFunction> rowCountParser = (rs) -> { + String content = rs.getString(1); + if (content.contains("count:")) { + try { + long number = Long.parseLong(content.split("count: ")[1].split(" ")[0].replace(",", "")); + return Optional.of(number); + } catch (Exception e) { // To avoid the situation that no number is found + } + } + return Optional.empty(); + }; + CERTOracle.CheckedFunction> queryPlanParser = (rs) -> { + String content = rs.getString(1); + if (content.contains("• ")) { + String operation = content.split("• ")[1].split(" ")[0]; + if (CockroachDBBugs.bug131875 && (operation.equals("distinct") || operation.equals("limit"))) { + throw new IgnoreMeException(); + } + return Optional.of(operation); + } + return Optional.empty(); + }; + + return new CERTOracle<>(globalState, gen, expectedErrors, rowCountParser, queryPlanParser); + } + + @Override + public boolean requiresAllTablesToContainRows() { + return true; + } + }; + +} diff --git a/src/sqlancer/cockroachdb/CockroachDBProvider.java b/src/sqlancer/cockroachdb/CockroachDBProvider.java index abc7e6e20..66e80fd59 100644 --- a/src/sqlancer/cockroachdb/CockroachDBProvider.java +++ b/src/sqlancer/cockroachdb/CockroachDBProvider.java @@ -1,5 +1,6 @@ package sqlancer.cockroachdb; +import java.io.IOException; import java.sql.Connection; import java.sql.DriverManager; import java.sql.SQLException; @@ -23,6 +24,8 @@ import sqlancer.cockroachdb.gen.CockroachDBCommentOnGenerator; import sqlancer.cockroachdb.gen.CockroachDBCreateStatisticsGenerator; import sqlancer.cockroachdb.gen.CockroachDBDeleteGenerator; +import sqlancer.cockroachdb.gen.CockroachDBDropTableGenerator; +import sqlancer.cockroachdb.gen.CockroachDBDropViewGenerator; import sqlancer.cockroachdb.gen.CockroachDBIndexGenerator; import sqlancer.cockroachdb.gen.CockroachDBInsertGenerator; import sqlancer.cockroachdb.gen.CockroachDBRandomQuerySynthesizer; @@ -36,6 +39,7 @@ import sqlancer.common.query.ExpectedErrors; import sqlancer.common.query.SQLQueryAdapter; import sqlancer.common.query.SQLQueryProvider; +import sqlancer.common.query.SQLancerResultSet; @AutoService(DatabaseProvider.class) public class CockroachDBProvider extends SQLProviderAdapter { @@ -45,15 +49,17 @@ public CockroachDBProvider() { } public enum Action { - INSERT(CockroachDBInsertGenerator::insert), // - TRUNCATE(CockroachDBTruncateGenerator::truncate), // + CREATE_TABLE(CockroachDBTableGenerator::generate), CREATE_INDEX(CockroachDBIndexGenerator::create), // + CREATE_VIEW(CockroachDBViewGenerator::generate), // CREATE_STATISTICS(CockroachDBCreateStatisticsGenerator::create), // - SET_SESSION(CockroachDBSetSessionGenerator::create), // - CREATE_INDEX(CockroachDBIndexGenerator::create), // + INSERT(CockroachDBInsertGenerator::insert), // UPDATE(CockroachDBUpdateGenerator::gen), // - CREATE_VIEW(CockroachDBViewGenerator::generate), // + SET_SESSION(CockroachDBSetSessionGenerator::create), // SET_CLUSTER_SETTING(CockroachDBSetClusterSettingGenerator::create), // DELETE(CockroachDBDeleteGenerator::delete), // + TRUNCATE(CockroachDBTruncateGenerator::truncate), // + DROP_TABLE(CockroachDBDropTableGenerator::drop), // + DROP_VIEW(CockroachDBDropViewGenerator::drop), // COMMENT_ON(CockroachDBCommentOnGenerator::comment), // SHOW(CockroachDBShowGenerator::show), // TRANSACTION((g) -> { @@ -65,8 +71,7 @@ public enum Action { ExpectedErrors errors = new ExpectedErrors(); if (Randomly.getBoolean()) { sb.append("("); - sb.append(Randomly.nonEmptySubset("VERBOSE", "TYPES", "OPT", "DISTSQL", "VEC").stream() - .collect(Collectors.joining(", "))); + sb.append(Randomly.fromOptions("VERBOSE", "TYPES", "OPT", "DISTSQL", "VEC")); sb.append(") "); errors.add("cannot set EXPLAIN mode more than once"); errors.add("unable to vectorize execution plan"); @@ -125,13 +130,13 @@ public void generateDatabase(CockroachDBGlobalState globalState) throws Exceptio QueryManager manager = globalState.getManager(); MainOptions options = globalState.getOptions(); List standardSettings = new ArrayList<>(); - standardSettings.add("--Don't send automatic bug reports\n" - + "SET CLUSTER SETTING debug.panic_on_failed_assertions = true;"); + standardSettings.add("--Don't send automatic bug reports"); + standardSettings.add("SET CLUSTER SETTING debug.panic_on_failed_assertions = true;"); standardSettings.add("SET CLUSTER SETTING diagnostics.reporting.enabled = false;"); standardSettings.add("SET CLUSTER SETTING diagnostics.reporting.send_crash_reports = false;"); - standardSettings.add("-- Disable the collection of metrics and hope that it helps performance\n" - + "SET CLUSTER SETTING sql.metrics.statement_details.enabled = 'off'"); + standardSettings.add("-- Disable the collection of metrics and hope that it helps performance"); + standardSettings.add("SET CLUSTER SETTING sql.metrics.statement_details.enabled = 'off'"); standardSettings.add("SET CLUSTER SETTING sql.metrics.statement_details.plan_collection.enabled = 'off'"); standardSettings.add("SET CLUSTER SETTING sql.stats.automatic_collection.enabled = 'off'"); standardSettings.add("SET CLUSTER SETTING timeseries.storage.enabled = 'off'"); @@ -199,6 +204,9 @@ public void generateDatabase(CockroachDBGlobalState globalState) throws Exceptio */ break; case TRANSACTION: + case CREATE_TABLE: + case DROP_TABLE: + case DROP_VIEW: nrPerformed = 0; // r.getInteger(0, 0); break; default: @@ -242,8 +250,15 @@ public void generateDatabase(CockroachDBGlobalState globalState) throws Exceptio } total--; } - if (globalState.getDbmsSpecificOptions().makeVectorizationMoreLikely && Randomly.getBoolean()) { - manager.execute(new SQLQueryAdapter("SET vectorize=on;")); + + if (globalState.getDbmsSpecificOptions().getTestOracleFactory().stream() + .anyMatch((o) -> o == CockroachDBOracleFactory.CERT)) { + // Enfore statistic collected for all tables + ExpectedErrors errors = new ExpectedErrors(); + CockroachDBErrors.addExpressionErrors(errors); + for (CockroachDBTable table : globalState.getSchema().getDatabaseTables()) { + globalState.executeStatement(new SQLQueryAdapter("ANALYZE " + table.getName() + ";", errors)); + } } } @@ -273,7 +288,7 @@ public SQLConnection createDatabase(CockroachDBGlobalState globalState) throws S s.execute(createDatabaseCommand); } con.close(); - con = DriverManager.getConnection("jdbc:postgresql://localhost:26257/" + databaseName, + con = DriverManager.getConnection(String.format("jdbc:postgresql://%s:%d/%s", host, port, databaseName), globalState.getOptions().getUserName(), globalState.getOptions().getPassword()); return new SQLConnection(con); } @@ -283,4 +298,68 @@ public String getDBMSName() { return "cockroachdb"; } + @Override + public String getQueryPlan(String selectStr, CockroachDBGlobalState globalState) throws Exception { + String queryPlan = ""; + String explainQuery = "EXPLAIN (OPT) " + selectStr; + if (globalState.getOptions().logEachSelect()) { + globalState.getLogger().writeCurrent(explainQuery); + try { + globalState.getLogger().getCurrentFileWriter().flush(); + } catch (IOException e) { + e.printStackTrace(); + } + } + SQLQueryAdapter q = new SQLQueryAdapter(explainQuery); + boolean afterProjection = false; // Remove the concrete expression after each Projection operator + try (SQLancerResultSet rs = q.executeAndGet(globalState)) { + if (rs != null) { + while (rs.next()) { + String targetQueryPlan = rs.getString(1).replace("└──", "").replace("├──", "").replace("│", "") + .trim() + ";"; // Unify format + if (afterProjection) { + afterProjection = false; + continue; + } + if (targetQueryPlan.startsWith("projections")) { + afterProjection = true; + } + // Remove all concrete expressions by keywords + if (targetQueryPlan.contains(">") || targetQueryPlan.contains("<") || targetQueryPlan.contains("=") + || targetQueryPlan.contains("*") || targetQueryPlan.contains("+") + || targetQueryPlan.contains("'")) { + continue; + } + queryPlan += targetQueryPlan; + } + } + } catch (AssertionError e) { + throw new AssertionError("Explain failed: " + explainQuery); + } + + return queryPlan; + } + + @Override + protected double[] initializeWeightedAverageReward() { + return new double[Action.values().length]; + } + + @Override + protected void executeMutator(int index, CockroachDBGlobalState globalState) throws Exception { + SQLQueryAdapter queryMutateTable = Action.values()[index].getQuery(globalState); + globalState.executeStatement(queryMutateTable); + } + + @Override + public boolean addRowsToAllTables(CockroachDBGlobalState globalState) throws Exception { + List tablesNoRow = globalState.getSchema().getDatabaseTables().stream() + .filter(t -> t.getNrRows(globalState) == 0).collect(Collectors.toList()); + for (CockroachDBTable table : tablesNoRow) { + SQLQueryAdapter queryAddRows = CockroachDBInsertGenerator.insert(globalState, table); + globalState.executeStatement(queryAddRows); + } + return true; + } + } diff --git a/src/sqlancer/cockroachdb/CockroachDBSchema.java b/src/sqlancer/cockroachdb/CockroachDBSchema.java index c245f4cce..cbde577cb 100644 --- a/src/sqlancer/cockroachdb/CockroachDBSchema.java +++ b/src/sqlancer/cockroachdb/CockroachDBSchema.java @@ -182,6 +182,7 @@ public CockroachDBColumn(String name, CockroachDBCompositeDataType columnType, b this.isNullable = isNullable; } + @Override public boolean isPrimaryKey() { return isPrimaryKey; } @@ -208,6 +209,10 @@ public CockroachDBTables getRandomTableNonEmptyTables() { return new CockroachDBTables(Randomly.nonEmptySubset(getDatabaseTables())); } + public CockroachDBTables getRandomTableNonEmptyTables(int nr) { + return new CockroachDBTables(Randomly.nonEmptySubsetLeast(getDatabaseTables(), nr)); + } + private static CockroachDBCompositeDataType getColumnType(String typeString) { if (typeString.endsWith("[]")) { String substring = typeString.substring(0, typeString.length() - 2); @@ -288,7 +293,7 @@ public static CockroachDBSchema fromConnection(SQLConnection con, String databas for (String tableName : tableNames) { List databaseColumns = getTableColumns(con, tableName); List indexes = getIndexes(con, tableName); - boolean isView = tableName.startsWith("v"); + boolean isView = matchesViewName(tableName); CockroachDBTable t = new CockroachDBTable(tableName, databaseColumns, indexes, isView); for (CockroachDBColumn c : databaseColumns) { c.setTable(t); diff --git a/src/sqlancer/cockroachdb/CockroachDBToStringVisitor.java b/src/sqlancer/cockroachdb/CockroachDBToStringVisitor.java index 17a09a401..67abfdbbe 100644 --- a/src/sqlancer/cockroachdb/CockroachDBToStringVisitor.java +++ b/src/sqlancer/cockroachdb/CockroachDBToStringVisitor.java @@ -102,9 +102,9 @@ public void visit(CockroachDBSelect select) { sb.append(" HAVING "); visit(select.getHavingClause()); } - if (!select.getOrderByExpressions().isEmpty()) { + if (!select.getOrderByClauses().isEmpty()) { sb.append(" ORDER BY "); - visit(select.getOrderByExpressions()); + visit(select.getOrderByClauses()); } if (select.getLimitClause() != null) { sb.append(" LIMIT "); @@ -142,54 +142,65 @@ public void visit(CockroachDBJoin join) { switch (join.getJoinType()) { case INNER: sb.append(" INNER "); - potentiallyAddHint(); + potentiallyAddHint(false); sb.append("JOIN "); visit(join.getRightTable()); sb.append(" ON "); visit(join.getOnCondition()); break; - case NATURAL: - sb.append(" NATURAL "); - // potentiallyAddHint(); + case LEFT: + sb.append(" LEFT"); + sb.append(" OUTER "); + potentiallyAddHint(true); sb.append("JOIN "); visit(join.getRightTable()); + sb.append(" ON "); + visit(join.getOnCondition()); break; - case CROSS: - sb.append(" CROSS "); - potentiallyAddHint(); + case RIGHT: + sb.append(" RIGHT"); + sb.append(" OUTER "); + potentiallyAddHint(true); sb.append("JOIN "); visit(join.getRightTable()); + sb.append(" ON "); + visit(join.getOnCondition()); break; - case OUTER: - sb.append(" "); - switch (join.getOuterType()) { - case FULL: - sb.append("FULL"); - break; - case LEFT: - sb.append("LEFT"); - break; - case RIGHT: - sb.append("RIGHT"); - break; - default: - throw new AssertionError(); - } + case FULL: + sb.append(" FULL"); sb.append(" OUTER "); - potentiallyAddHint(); + potentiallyAddHint(true); sb.append("JOIN "); visit(join.getRightTable()); sb.append(" ON "); visit(join.getOnCondition()); break; + case CROSS: + sb.append(" CROSS "); + potentiallyAddHint(false); + sb.append("JOIN "); + visit(join.getRightTable()); + break; + case NATURAL: + sb.append(" NATURAL "); + // potentiallyAddHint(false); + sb.append("JOIN "); + visit(join.getRightTable()); + break; default: throw new AssertionError(); } } - private void potentiallyAddHint() { + private void potentiallyAddHint(boolean isOuter) { if (Randomly.getBoolean()) { - sb.append(Randomly.fromOptions("HASH", "MERGE", "LOOKUP")); + return; + } else { + if (isOuter) { + sb.append(Randomly.fromOptions("HASH", "MERGE", "LOOKUP")); + } else { + sb.append(Randomly.fromOptions("HASH", "MERGE")); + } sb.append(" "); } } diff --git a/src/sqlancer/cockroachdb/ast/CockroachDBExpression.java b/src/sqlancer/cockroachdb/ast/CockroachDBExpression.java index 555fb97fd..d0ac07310 100644 --- a/src/sqlancer/cockroachdb/ast/CockroachDBExpression.java +++ b/src/sqlancer/cockroachdb/ast/CockroachDBExpression.java @@ -1,5 +1,8 @@ package sqlancer.cockroachdb.ast; -public interface CockroachDBExpression { +import sqlancer.cockroachdb.CockroachDBSchema.CockroachDBColumn; +import sqlancer.common.ast.newast.Expression; + +public interface CockroachDBExpression extends Expression { } diff --git a/src/sqlancer/cockroachdb/ast/CockroachDBJoin.java b/src/sqlancer/cockroachdb/ast/CockroachDBJoin.java index 740a7901c..a8cba22f9 100644 --- a/src/sqlancer/cockroachdb/ast/CockroachDBJoin.java +++ b/src/sqlancer/cockroachdb/ast/CockroachDBJoin.java @@ -1,28 +1,31 @@ package sqlancer.cockroachdb.ast; +import java.util.Arrays; + import sqlancer.Randomly; +import sqlancer.cockroachdb.CockroachDBSchema.CockroachDBColumn; +import sqlancer.cockroachdb.CockroachDBSchema.CockroachDBTable; +import sqlancer.common.ast.newast.Join; -public class CockroachDBJoin implements CockroachDBExpression { +public class CockroachDBJoin + implements CockroachDBExpression, Join { private final CockroachDBExpression leftTable; private final CockroachDBExpression rightTable; - private final JoinType joinType; - private final CockroachDBExpression onCondition; - private OuterType outerType; + private JoinType joinType; + private CockroachDBExpression onCondition; public enum JoinType { - INNER, NATURAL, CROSS, OUTER; + INNER, LEFT, RIGHT, FULL, CROSS, NATURAL; public static JoinType getRandom() { return Randomly.fromOptions(values()); } - } - - public enum OuterType { - FULL, LEFT, RIGHT; - public static OuterType getRandom() { - return Randomly.fromOptions(values()); + public static JoinType getRandomExcept(JoinType... exclude) { + JoinType[] values = Arrays.stream(values()).filter(m -> !Arrays.asList(exclude).contains(m)) + .toArray(JoinType[]::new); + return Randomly.fromOptions(values); } } @@ -42,35 +45,29 @@ public CockroachDBExpression getRightTable() { return rightTable; } - public JoinType getJoinType() { - return joinType; - } - - public CockroachDBExpression getOnCondition() { - return onCondition; - } - - public static CockroachDBJoin createNaturalJoin(CockroachDBExpression left, CockroachDBExpression right) { - return new CockroachDBJoin(left, right, JoinType.NATURAL, null); + public void setJoinType(JoinType joinType) { + this.joinType = joinType; } - public static CockroachDBJoin createCrossJoin(CockroachDBExpression left, CockroachDBExpression right) { - return new CockroachDBJoin(left, right, JoinType.CROSS, null); + public JoinType getJoinType() { + return joinType; } - public static CockroachDBJoin createOuterJoin(CockroachDBExpression left, CockroachDBExpression right, - OuterType type, CockroachDBExpression onClause) { - CockroachDBJoin join = new CockroachDBJoin(left, right, JoinType.OUTER, onClause); - join.setOuterType(type); - return join; + @Override + public void setOnClause(CockroachDBExpression onCondition) { + this.onCondition = onCondition; } - private void setOuterType(OuterType outerType) { - this.outerType = outerType; + public CockroachDBExpression getOnCondition() { + return onCondition; } - public OuterType getOuterType() { - return outerType; + public static CockroachDBJoin createJoin(CockroachDBExpression left, CockroachDBExpression right, JoinType type, + CockroachDBExpression onClause) { + if (type.compareTo(JoinType.CROSS) >= 0) { + return new CockroachDBJoin(left, right, type, null); + } else { + return new CockroachDBJoin(left, right, type, onClause); + } } - } diff --git a/src/sqlancer/cockroachdb/ast/CockroachDBSelect.java b/src/sqlancer/cockroachdb/ast/CockroachDBSelect.java index 2c7d33ff4..c043a9207 100644 --- a/src/sqlancer/cockroachdb/ast/CockroachDBSelect.java +++ b/src/sqlancer/cockroachdb/ast/CockroachDBSelect.java @@ -1,8 +1,16 @@ package sqlancer.cockroachdb.ast; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.cockroachdb.CockroachDBSchema.CockroachDBColumn; +import sqlancer.cockroachdb.CockroachDBSchema.CockroachDBTable; +import sqlancer.cockroachdb.CockroachDBVisitor; import sqlancer.common.ast.SelectBase; +import sqlancer.common.ast.newast.Select; -public class CockroachDBSelect extends SelectBase implements CockroachDBExpression { +public class CockroachDBSelect extends SelectBase implements CockroachDBExpression, + Select { private boolean isDistinct; @@ -14,4 +22,21 @@ public void setDistinct(boolean isDistinct) { this.isDistinct = isDistinct; } + @Override + public void setJoinClauses(List joinStatements) { + List expressions = joinStatements.stream().map(e -> (CockroachDBExpression) e) + .collect(Collectors.toList()); + setJoinList(expressions); + } + + @Override + public List getJoinClauses() { + return getJoinList().stream().map(e -> (CockroachDBJoin) e).collect(Collectors.toList()); + } + + @Override + public String asString() { + return CockroachDBVisitor.asString(this); + } + } diff --git a/src/sqlancer/cockroachdb/gen/CockroachDBDropTableGenerator.java b/src/sqlancer/cockroachdb/gen/CockroachDBDropTableGenerator.java new file mode 100644 index 000000000..3834c8757 --- /dev/null +++ b/src/sqlancer/cockroachdb/gen/CockroachDBDropTableGenerator.java @@ -0,0 +1,35 @@ +package sqlancer.cockroachdb.gen; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.cockroachdb.CockroachDBProvider.CockroachDBGlobalState; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; + +public final class CockroachDBDropTableGenerator { + + private CockroachDBDropTableGenerator() { + } + + public static SQLQueryAdapter drop(CockroachDBGlobalState globalState) { + if (globalState.getSchema().getTables(t -> !t.isView()).size() <= 1) { + throw new IgnoreMeException(); + } + + ExpectedErrors errors = new ExpectedErrors(); + errors.add("is referenced by foreign key"); + + StringBuilder sb = new StringBuilder(); + sb.append("DROP"); + sb.append(" TABLE"); + sb.append(" "); + sb.append(globalState.getSchema().getRandomTable(t -> !t.isView()).getName()); + + if (Randomly.getBoolean()) { + sb.append(" "); + sb.append(Randomly.fromOptions("CASCADE", "RESTRICT")); + } + return new SQLQueryAdapter(sb.toString(), true); + } + +} diff --git a/src/sqlancer/cockroachdb/gen/CockroachDBDropViewGenerator.java b/src/sqlancer/cockroachdb/gen/CockroachDBDropViewGenerator.java new file mode 100644 index 000000000..5d440b564 --- /dev/null +++ b/src/sqlancer/cockroachdb/gen/CockroachDBDropViewGenerator.java @@ -0,0 +1,41 @@ +package sqlancer.cockroachdb.gen; + +import sqlancer.Randomly; +import sqlancer.cockroachdb.CockroachDBProvider.CockroachDBGlobalState; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; + +public final class CockroachDBDropViewGenerator { + + private CockroachDBDropViewGenerator() { + } + + public static SQLQueryAdapter drop(CockroachDBGlobalState globalState) { + ExpectedErrors errors = new ExpectedErrors(); + errors.add("is referenced by foreign key"); + + StringBuilder sb = new StringBuilder(); + sb.append("DROP"); + if (Randomly.getBoolean()) { + sb.append(" MATERIALIZED"); + } + sb.append(" VIEW"); + sb.append(" "); + if (Randomly.getBooleanWithRatherLowProbability()) { + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + if (i != 0) { + sb.append(", "); + } + sb.append(globalState.getSchema().getRandomTable(t -> t.isView()).getName()); + } + } else { + sb.append(globalState.getSchema().getRandomTable(t -> t.isView()).getName()); + } + if (Randomly.getBoolean()) { + sb.append(" "); + sb.append(Randomly.fromOptions("CASCADE", "RESTRICT")); + } + return new SQLQueryAdapter(sb.toString(), true); + } + +} diff --git a/src/sqlancer/cockroachdb/gen/CockroachDBExpressionGenerator.java b/src/sqlancer/cockroachdb/gen/CockroachDBExpressionGenerator.java index 6326bcc95..8aeb5492b 100644 --- a/src/sqlancer/cockroachdb/gen/CockroachDBExpressionGenerator.java +++ b/src/sqlancer/cockroachdb/gen/CockroachDBExpressionGenerator.java @@ -2,14 +2,18 @@ import java.util.ArrayList; import java.util.List; +import java.util.function.Function; import java.util.stream.Collectors; import sqlancer.Randomly; +import sqlancer.cockroachdb.CockroachDBBugs; import sqlancer.cockroachdb.CockroachDBCommon; import sqlancer.cockroachdb.CockroachDBProvider.CockroachDBGlobalState; import sqlancer.cockroachdb.CockroachDBSchema.CockroachDBColumn; import sqlancer.cockroachdb.CockroachDBSchema.CockroachDBCompositeDataType; import sqlancer.cockroachdb.CockroachDBSchema.CockroachDBDataType; +import sqlancer.cockroachdb.CockroachDBSchema.CockroachDBTable; +import sqlancer.cockroachdb.CockroachDBVisitor; import sqlancer.cockroachdb.ast.CockroachDBAggregate; import sqlancer.cockroachdb.ast.CockroachDBAggregate.CockroachDBAggregateFunction; import sqlancer.cockroachdb.ast.CockroachDBBetweenOperation; @@ -29,6 +33,8 @@ import sqlancer.cockroachdb.ast.CockroachDBExpression; import sqlancer.cockroachdb.ast.CockroachDBFunction; import sqlancer.cockroachdb.ast.CockroachDBInOperation; +import sqlancer.cockroachdb.ast.CockroachDBJoin; +import sqlancer.cockroachdb.ast.CockroachDBJoin.JoinType; import sqlancer.cockroachdb.ast.CockroachDBMultiValuedComparison; import sqlancer.cockroachdb.ast.CockroachDBMultiValuedComparison.MultiValuedComparisonOperator; import sqlancer.cockroachdb.ast.CockroachDBMultiValuedComparison.MultiValuedComparisonType; @@ -36,14 +42,24 @@ import sqlancer.cockroachdb.ast.CockroachDBOrderingTerm; import sqlancer.cockroachdb.ast.CockroachDBRegexOperation; import sqlancer.cockroachdb.ast.CockroachDBRegexOperation.CockroachDBRegexOperator; +import sqlancer.cockroachdb.ast.CockroachDBSelect; +import sqlancer.cockroachdb.ast.CockroachDBTableReference; import sqlancer.cockroachdb.ast.CockroachDBTypeAnnotation; import sqlancer.cockroachdb.ast.CockroachDBUnaryPostfixOperation; import sqlancer.cockroachdb.ast.CockroachDBUnaryPostfixOperation.CockroachDBUnaryPostfixOperator; +import sqlancer.common.gen.CERTGenerator; +import sqlancer.common.gen.NoRECGenerator; +import sqlancer.common.gen.TLPWhereGenerator; import sqlancer.common.gen.TypedExpressionGenerator; +import sqlancer.common.schema.AbstractTables; -public class CockroachDBExpressionGenerator - extends TypedExpressionGenerator { +public class CockroachDBExpressionGenerator extends + TypedExpressionGenerator implements + NoRECGenerator, + TLPWhereGenerator, + CERTGenerator { + private List tables; private final CockroachDBGlobalState globalState; public CockroachDBExpressionGenerator(CockroachDBGlobalState globalState) { @@ -84,7 +100,8 @@ public List getOrderingTerms() { @Override public CockroachDBExpression generateExpression(CockroachDBCompositeDataType type, int depth) { - // if (type == CockroachDBDataType.FLOAT && Randomly.getBooleanWithRatherLowProbability()) { + // if (type == CockroachDBDataType.FLOAT && + // Randomly.getBooleanWithRatherLowProbability()) { // type = CockroachDBDataType.INT; // } if (allowAggregates && Randomly.getBoolean()) { @@ -358,4 +375,227 @@ public CockroachDBExpression isNull(CockroachDBExpression expr) { return new CockroachDBUnaryPostfixOperation(expr, CockroachDBUnaryPostfixOperator.IS_NULL); } + @Override + public CockroachDBExpressionGenerator setTablesAndColumns( + AbstractTables tables) { + this.columns = tables.getColumns(); + this.tables = tables.getTables(); + + return this; + } + + @Override + public CockroachDBExpression generateBooleanExpression() { + return generateExpression(CockroachDBDataType.BOOL.get()); + } + + @Override + public CockroachDBSelect generateSelect() { + return new CockroachDBSelect(); + } + + @Override + public List getRandomJoinClauses() { + List joinExpressions = new ArrayList<>(); + List tableReferences = tables.stream().map(t -> new CockroachDBTableReference(t)) + .collect(Collectors.toList()); + while (tableReferences.size() >= 2 && Randomly.getBoolean()) { + CockroachDBTableReference leftTable = tableReferences.remove(0); + CockroachDBTableReference rightTable = tableReferences.remove(0); + List columns = new ArrayList<>(leftTable.getTable().getColumns()); + columns.addAll(rightTable.getTable().getColumns()); + CockroachDBExpressionGenerator joinGen = new CockroachDBExpressionGenerator(globalState) + .setColumns(columns); + joinExpressions.add(CockroachDBJoin.createJoin(leftTable, rightTable, CockroachDBJoin.JoinType.getRandom(), + joinGen.generateExpression(CockroachDBDataType.BOOL.get()))); + } + + tables = tableReferences.stream().map(t -> t.getTable()).collect(Collectors.toList()); + return joinExpressions; + } + + @Override + public List getTableRefs() { + List tableReferences = tables.stream().map(t -> new CockroachDBTableReference(t)) + .collect(Collectors.toList()); + + return CockroachDBCommon.getTableReferences(tableReferences); + } + + @Override + public String generateOptimizedQueryString(CockroachDBSelect select, CockroachDBExpression whereCondition, + boolean shouldUseAggregate) { + CockroachDBColumn c = new CockroachDBColumn("COUNT(*)", null, false, false); + select.setWhereClause(whereCondition); + if (shouldUseAggregate) { + CockroachDBAggregate aggr = new CockroachDBAggregate(CockroachDBAggregateFunction.COUNT, + List.of(new CockroachDBColumnReference(new CockroachDBColumn("*", + new CockroachDBCompositeDataType(CockroachDBDataType.INT, 0), false, false)))); + select.setFetchColumns(List.of(aggr)); + } else { + select.setFetchColumns(List.of(new CockroachDBColumnReference(c))); + if (Randomly.getBooleanWithRatherLowProbability()) { + select.setOrderByClauses(getOrderingTerms()); + } + } + return CockroachDBVisitor.asString(select); + } + + @Override + public String generateUnoptimizedQueryString(CockroachDBSelect select, CockroachDBExpression whereCondition) { + List tableList = select.getFromList(); + List joinList = select.getJoinList(); + String fromString = tableList.stream().map(t -> ((CockroachDBTableReference) t).getTable().getName()) + .collect(Collectors.joining(", ")); + if (!tableList.isEmpty() && !joinList.isEmpty()) { + fromString += ", "; + } + return "SELECT SUM(count) FROM (SELECT CAST(" + CockroachDBVisitor.asString(whereCondition) + + " IS TRUE AS INT) as count FROM " + fromString + " " + + joinList.stream().map(j -> CockroachDBVisitor.asString(j)).collect(Collectors.joining(", ")) + ")"; + } + + @Override + public List generateFetchColumns(boolean shouldCreateDummy) { + if (shouldCreateDummy || columns.isEmpty()) { + return List.of(new CockroachDBColumnReference(new CockroachDBColumn("*", null, false, false))); + } + return Randomly.nonEmptySubset(columns).stream().map(c -> new CockroachDBColumnReference(c)) + .collect(Collectors.toList()); + } + + @Override + public String generateExplainQuery(CockroachDBSelect select) { + return "EXPLAIN " + select.asString(); + } + + @Override + public boolean mutate(CockroachDBSelect select) { + List> mutators = new ArrayList<>(); + + if (!CockroachDBBugs.bug131647) { + mutators.add(this::mutateJoin); + } + mutators.add(this::mutateGroupBy); + mutators.add(this::mutateHaving); + mutators.add(this::mutateAnd); + if (!CockroachDBBugs.bug131640) { + mutators.add(this::mutateWhere); + mutators.add(this::mutateOr); + } + // mutators.add(this::mutateLimit); + mutators.add(this::mutateDistinct); + + return Randomly.fromList(mutators).apply(select); + } + + boolean mutateJoin(CockroachDBSelect select) { + if (select.getJoinList().isEmpty()) { + return false; + } + + CockroachDBJoin join = (CockroachDBJoin) Randomly.fromList(select.getJoinList()); + + // CROSS does not need ON Condition, while other joins do + // To avoid Null pointer, generating a new new condition when mutating CROSS to other joins + if (join.getJoinType() == JoinType.CROSS) { + List columns = new ArrayList<>(); + columns.addAll(((CockroachDBTableReference) join.getLeftTable()).getTable().getColumns()); + columns.addAll(((CockroachDBTableReference) join.getRightTable()).getTable().getColumns()); + CockroachDBExpressionGenerator joinGen2 = new CockroachDBExpressionGenerator(globalState) + .setColumns(columns); + join.setOnClause(joinGen2.generateExpression(CockroachDBDataType.BOOL.get())); + } + + JoinType newJoinType = CockroachDBJoin.JoinType.INNER; + if (join.getJoinType() == JoinType.LEFT || join.getJoinType() == JoinType.RIGHT) { // No invariant relation + // between LEFT and RIGHT + // join + newJoinType = CockroachDBJoin.JoinType.getRandomExcept(JoinType.NATURAL, JoinType.CROSS, JoinType.LEFT, + JoinType.RIGHT); + } else if (join.getJoinType() == JoinType.FULL) { + newJoinType = CockroachDBJoin.JoinType.getRandomExcept(JoinType.NATURAL, JoinType.CROSS); + } else if (join.getJoinType() != JoinType.CROSS) { + newJoinType = CockroachDBJoin.JoinType.getRandomExcept(JoinType.NATURAL, join.getJoinType()); + } + assert newJoinType != JoinType.NATURAL; // Natural Join is not supported for CERT + boolean increase = join.getJoinType().ordinal() < newJoinType.ordinal(); + join.setJoinType(newJoinType); + return increase; + } + + boolean mutateDistinct(CockroachDBSelect select) { + boolean increase = select.isDistinct(); + select.setDistinct(!select.isDistinct()); + return increase; + } + + boolean mutateWhere(CockroachDBSelect select) { + boolean increase = select.getWhereClause() != null; + if (increase) { + select.setWhereClause(null); + } else { + select.setWhereClause(generateExpression(CockroachDBDataType.BOOL.get())); + } + return increase; + } + + boolean mutateGroupBy(CockroachDBSelect select) { + boolean increase = !select.getGroupByExpressions().isEmpty(); + if (increase) { + select.clearGroupByExpressions(); + } else { + select.setGroupByExpressions(select.getFetchColumns()); + } + return increase; + } + + boolean mutateHaving(CockroachDBSelect select) { + if (select.getGroupByExpressions().isEmpty()) { + select.setGroupByExpressions(select.getFetchColumns()); + select.setHavingClause(generateExpression(CockroachDBDataType.BOOL.get())); + return false; + } else { + if (select.getHavingClause() == null) { + select.setHavingClause(generateExpression(CockroachDBDataType.BOOL.get())); + return false; + } else { + select.setHavingClause(null); + return true; + } + } + } + + boolean mutateAnd(CockroachDBSelect select) { + if (select.getWhereClause() == null) { + select.setWhereClause(generateExpression(CockroachDBDataType.BOOL.get())); + } else { + CockroachDBExpression newWhere = new CockroachDBBinaryLogicalOperation(select.getWhereClause(), + generateExpression(CockroachDBDataType.BOOL.get()), CockroachDBBinaryLogicalOperator.AND); + select.setWhereClause(newWhere); + } + return false; + } + + boolean mutateOr(CockroachDBSelect select) { + if (select.getWhereClause() == null) { + select.setWhereClause(generateExpression(CockroachDBDataType.BOOL.get())); + return false; + } else { + CockroachDBExpression newWhere = new CockroachDBBinaryLogicalOperation(select.getWhereClause(), + generateExpression(CockroachDBDataType.BOOL.get()), CockroachDBBinaryLogicalOperator.OR); + select.setWhereClause(newWhere); + return true; + } + } + + boolean mutateLimit(CockroachDBSelect select) { + boolean increase = select.getLimitClause() != null; + if (increase) { + select.setLimitClause(null); + } else { + select.setLimitClause(generateConstant(CockroachDBDataType.INT.get())); + } + return increase; + } } diff --git a/src/sqlancer/cockroachdb/gen/CockroachDBIndexGenerator.java b/src/sqlancer/cockroachdb/gen/CockroachDBIndexGenerator.java index e5b708338..4db3588ac 100644 --- a/src/sqlancer/cockroachdb/gen/CockroachDBIndexGenerator.java +++ b/src/sqlancer/cockroachdb/gen/CockroachDBIndexGenerator.java @@ -2,6 +2,7 @@ import java.util.List; +import sqlancer.IgnoreMeException; import sqlancer.Randomly; import sqlancer.cockroachdb.CockroachDBBugs; import sqlancer.cockroachdb.CockroachDBProvider.CockroachDBGlobalState; @@ -17,6 +18,9 @@ public CockroachDBIndexGenerator(CockroachDBGlobalState globalState) { } public static SQLQueryAdapter create(CockroachDBGlobalState s) { + if (s.getSchema().getIndexCount() >= s.getDbmsSpecificOptions().maxNumIndexes) { + throw new IgnoreMeException(); + } return new CockroachDBIndexGenerator(s).getQuery(); } diff --git a/src/sqlancer/cockroachdb/gen/CockroachDBInsertGenerator.java b/src/sqlancer/cockroachdb/gen/CockroachDBInsertGenerator.java index 47de622d8..9f823c59b 100644 --- a/src/sqlancer/cockroachdb/gen/CockroachDBInsertGenerator.java +++ b/src/sqlancer/cockroachdb/gen/CockroachDBInsertGenerator.java @@ -18,6 +18,11 @@ private CockroachDBInsertGenerator() { } public static SQLQueryAdapter insert(CockroachDBGlobalState globalState) { + CockroachDBTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + return insert(globalState, table); + } + + public static SQLQueryAdapter insert(CockroachDBGlobalState globalState, CockroachDBTable table) { ExpectedErrors errors = new ExpectedErrors(); CockroachDBErrors.addExpressionErrors(errors); // e.g., caused by computed columns @@ -32,7 +37,6 @@ public static SQLQueryAdapter insert(CockroachDBGlobalState globalState) { errors.add("foreign key violation"); errors.add("multi-part foreign key"); StringBuilder sb = new StringBuilder(); - CockroachDBTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); boolean isUpsert = Randomly.getBoolean(); if (!isUpsert) { sb.append("INSERT INTO "); @@ -77,8 +81,6 @@ public static SQLQueryAdapter insert(CockroachDBGlobalState globalState) { if (Randomly.getBoolean()) { sb.append(" NOTHING "); } else { - // TODO: also support excluded. (see - // https://www.cockroachlabs.com/docs/stable/insert.html) sb.append(" UPDATE SET "); List columns = table.getRandomNonEmptyColumnSubset(); int i = 0; @@ -88,7 +90,12 @@ public static SQLQueryAdapter insert(CockroachDBGlobalState globalState) { } sb.append(c.getName()); sb.append(" = "); - sb.append(CockroachDBVisitor.asString(gen.generateConstant(c.getType()))); + if (Randomly.getBoolean()) { + sb.append(CockroachDBVisitor.asString(gen.generateConstant(c.getType()))); + } else { + sb.append("excluded."); + sb.append(c.getName()); + } } errors.add("UPSERT or INSERT...ON CONFLICT command cannot affect row a second time"); } diff --git a/src/sqlancer/cockroachdb/gen/CockroachDBRandomQuerySynthesizer.java b/src/sqlancer/cockroachdb/gen/CockroachDBRandomQuerySynthesizer.java index d3679b015..880f33310 100644 --- a/src/sqlancer/cockroachdb/gen/CockroachDBRandomQuerySynthesizer.java +++ b/src/sqlancer/cockroachdb/gen/CockroachDBRandomQuerySynthesizer.java @@ -13,7 +13,7 @@ import sqlancer.cockroachdb.ast.CockroachDBExpression; import sqlancer.cockroachdb.ast.CockroachDBSelect; import sqlancer.cockroachdb.ast.CockroachDBTableReference; -import sqlancer.cockroachdb.oracle.CockroachDBNoRECOracle; +import sqlancer.cockroachdb.oracle.tlp.CockroachDBTLPBase; import sqlancer.common.query.SQLQueryAdapter; public final class CockroachDBRandomQuerySynthesizer { @@ -49,14 +49,14 @@ public static CockroachDBSelect generateSelect(CockroachDBGlobalState globalStat .map(t -> new CockroachDBTableReference(t)).collect(Collectors.toList()); List updatedTableList = CockroachDBCommon.getTableReferences(tableList); if (Randomly.getBoolean()) { - select.setJoinList(CockroachDBNoRECOracle.getJoins(updatedTableList, globalState)); + select.setJoinList(CockroachDBTLPBase.getJoins(updatedTableList, globalState)); } select.setFromList(updatedTableList); if (Randomly.getBoolean()) { select.setWhereClause(gen.generateExpression(CockroachDBDataType.BOOL.get())); } if (Randomly.getBoolean()) { - select.setOrderByExpressions(gen.getOrderingTerms()); + select.setOrderByClauses(gen.getOrderingTerms()); } if (Randomly.getBoolean()) { select.setGroupByExpressions(gen.generateExpressions(Randomly.smallNumber() + 1)); diff --git a/src/sqlancer/cockroachdb/gen/CockroachDBSetClusterSettingGenerator.java b/src/sqlancer/cockroachdb/gen/CockroachDBSetClusterSettingGenerator.java index 42740506d..a948a1237 100644 --- a/src/sqlancer/cockroachdb/gen/CockroachDBSetClusterSettingGenerator.java +++ b/src/sqlancer/cockroachdb/gen/CockroachDBSetClusterSettingGenerator.java @@ -17,10 +17,12 @@ private CockroachDBSetClusterSettingGenerator() { private enum CockroachDBClusterSetting { BACKPRESSURE_RANGE_SIZE_MULTIPLIER(" kv.range.backpressure_range_size_multiplier", (g) -> Randomly.getNotCachedInteger(0, Integer.MAX_VALUE)), - RANGE_DESCRIPTOR_CACHE_SIZE("kv.range_descriptor_cache.size", (g) -> Randomly.getNonCachedInteger()), + RANGE_DESCRIPTOR_CACHE_SIZE("kv.range_descriptor_cache.size", + (g) -> Randomly.getNotCachedInteger(0, Integer.MAX_VALUE)), SQL_QUERY_CACHE_ENABLED("sql.query_cache.enabled", CockroachDBSetSessionGenerator::onOff), SQL_STATS_HISTOGRAM_COLLECTION_ENABLED("sql.stats.histogram_collection.enabled", - CockroachDBSetSessionGenerator::onOff); + CockroachDBSetSessionGenerator::onOff), + HISTOGRAM_COLLECT("sql.stats.histogram_collection.enabled", CockroachDBSetSessionGenerator::onOff); private Function f; private String name; diff --git a/src/sqlancer/cockroachdb/gen/CockroachDBShowGenerator.java b/src/sqlancer/cockroachdb/gen/CockroachDBShowGenerator.java index 09abbfbd0..9b84d9c94 100644 --- a/src/sqlancer/cockroachdb/gen/CockroachDBShowGenerator.java +++ b/src/sqlancer/cockroachdb/gen/CockroachDBShowGenerator.java @@ -23,7 +23,7 @@ public static SQLQueryAdapter show(CockroachDBGlobalState globalState) { case EXPERIMENTAL_FINGERPRINTS: sb.append("SHOW EXPERIMENTAL_FINGERPRINTS FROM TABLE "); sb.append(globalState.getSchema().getRandomTable(t -> !t.isView()).getName()); - errors.add("as type bytes: bytea encoded value ends with incomplete escape sequence"); + errors.add("bytea encoded value ends with incomplete escape sequence"); errors.add("invalid bytea escape sequence"); break; case DATABASES: diff --git a/src/sqlancer/cockroachdb/gen/CockroachDBTableGenerator.java b/src/sqlancer/cockroachdb/gen/CockroachDBTableGenerator.java index 4e99ab31a..c8bfa3c6a 100644 --- a/src/sqlancer/cockroachdb/gen/CockroachDBTableGenerator.java +++ b/src/sqlancer/cockroachdb/gen/CockroachDBTableGenerator.java @@ -29,12 +29,16 @@ public CockroachDBTableGenerator(CockroachDBGlobalState globalState) { } public static SQLQueryAdapter generate(CockroachDBGlobalState globalState) { + if (globalState.getSchema().getDatabaseTables().size() > globalState.getDbmsSpecificOptions().maxNumTables) { + throw new IgnoreMeException(); + } return new CockroachDBTableGenerator(globalState).getQuery(); } @Override public void buildStatement() { errors.add("and thus is not indexable"); // array types are not indexable + errors.add("context-dependent operators are not allowed in STORED COMPUTED COLUMN"); if (globalState.getDbmsSpecificOptions().testTempTables) { errors.add("constraints on temporary tables may reference only temporary tables"); errors.add("constraints on permanent tables may reference only permanent tables"); diff --git a/src/sqlancer/cockroachdb/gen/CockroachDBTruncateGenerator.java b/src/sqlancer/cockroachdb/gen/CockroachDBTruncateGenerator.java index 904ee3a9f..44979487b 100644 --- a/src/sqlancer/cockroachdb/gen/CockroachDBTruncateGenerator.java +++ b/src/sqlancer/cockroachdb/gen/CockroachDBTruncateGenerator.java @@ -1,6 +1,7 @@ package sqlancer.cockroachdb.gen; import sqlancer.Randomly; +import sqlancer.cockroachdb.CockroachDBBugs; import sqlancer.cockroachdb.CockroachDBProvider.CockroachDBGlobalState; import sqlancer.common.query.ExpectedErrors; import sqlancer.common.query.SQLQueryAdapter; @@ -14,6 +15,9 @@ private CockroachDBTruncateGenerator() { public static SQLQueryAdapter truncate(CockroachDBGlobalState globalState) { ExpectedErrors errors = new ExpectedErrors(); errors.add("is referenced by foreign key"); + if (CockroachDBBugs.bug85230) { + errors.add("found in depended-on-by references, no such index in this relation"); + } StringBuilder sb = new StringBuilder(); sb.append("TRUNCATE"); diff --git a/src/sqlancer/cockroachdb/gen/CockroachDBUpdateGenerator.java b/src/sqlancer/cockroachdb/gen/CockroachDBUpdateGenerator.java index b8a96e4c7..8dcd605d7 100644 --- a/src/sqlancer/cockroachdb/gen/CockroachDBUpdateGenerator.java +++ b/src/sqlancer/cockroachdb/gen/CockroachDBUpdateGenerator.java @@ -9,20 +9,27 @@ import sqlancer.cockroachdb.CockroachDBSchema.CockroachDBDataType; import sqlancer.cockroachdb.CockroachDBSchema.CockroachDBTable; import sqlancer.cockroachdb.CockroachDBVisitor; -import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.gen.AbstractUpdateGenerator; import sqlancer.common.query.SQLQueryAdapter; -public final class CockroachDBUpdateGenerator { +public final class CockroachDBUpdateGenerator extends AbstractUpdateGenerator { - private CockroachDBUpdateGenerator() { + private final CockroachDBGlobalState globalState; + private CockroachDBExpressionGenerator gen; + + private CockroachDBUpdateGenerator(CockroachDBGlobalState globalState) { + this.globalState = globalState; } public static SQLQueryAdapter gen(CockroachDBGlobalState globalState) { - ExpectedErrors errors = new ExpectedErrors(); + return new CockroachDBUpdateGenerator(globalState).generate(); + } + + private SQLQueryAdapter generate() { CockroachDBTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); List columns = table.getRandomNonEmptyColumnSubset(); - CockroachDBExpressionGenerator gen = new CockroachDBExpressionGenerator(globalState).setColumns(columns); - StringBuilder sb = new StringBuilder("UPDATE "); + gen = new CockroachDBExpressionGenerator(globalState).setColumns(columns); + sb.append("UPDATE "); sb.append(table.getName()); if (Randomly.getBoolean()) { sb.append("@{FORCE_INDEX="); @@ -30,15 +37,7 @@ public static SQLQueryAdapter gen(CockroachDBGlobalState globalState) { sb.append("}"); } sb.append(" SET "); - int i = 0; - for (CockroachDBColumn c : columns) { - if (i++ != 0) { - sb.append(", "); - } - sb.append(c.getName()); - sb.append("="); - sb.append(CockroachDBVisitor.asString(gen.generateExpression(c.getType()))); - } + updateColumns(columns); if (Randomly.getBoolean()) { sb.append(" WHERE "); sb.append(CockroachDBVisitor.asString(gen.generateExpression(CockroachDBDataType.BOOL.get()))); @@ -55,4 +54,9 @@ public static SQLQueryAdapter gen(CockroachDBGlobalState globalState) { return new SQLQueryAdapter(sb.toString(), errors); } + @Override + protected void updateValue(CockroachDBColumn column) { + sb.append(CockroachDBVisitor.asString(gen.generateExpression(column.getType()))); + } + } diff --git a/src/sqlancer/cockroachdb/oracle/CockroachDBCERTOracle.java b/src/sqlancer/cockroachdb/oracle/CockroachDBCERTOracle.java new file mode 100644 index 000000000..7d2af5930 --- /dev/null +++ b/src/sqlancer/cockroachdb/oracle/CockroachDBCERTOracle.java @@ -0,0 +1,292 @@ +package sqlancer.cockroachdb.oracle; + +import java.io.IOException; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.SQLGlobalState; +import sqlancer.cockroachdb.CockroachDBBugs; +import sqlancer.cockroachdb.CockroachDBCommon; +import sqlancer.cockroachdb.CockroachDBErrors; +import sqlancer.cockroachdb.CockroachDBProvider.CockroachDBGlobalState; +import sqlancer.cockroachdb.CockroachDBSchema.CockroachDBColumn; +import sqlancer.cockroachdb.CockroachDBSchema.CockroachDBDataType; +import sqlancer.cockroachdb.CockroachDBSchema.CockroachDBTables; +import sqlancer.cockroachdb.CockroachDBVisitor; +import sqlancer.cockroachdb.ast.CockroachDBBinaryLogicalOperation; +import sqlancer.cockroachdb.ast.CockroachDBBinaryLogicalOperation.CockroachDBBinaryLogicalOperator; +import sqlancer.cockroachdb.ast.CockroachDBColumnReference; +import sqlancer.cockroachdb.ast.CockroachDBExpression; +import sqlancer.cockroachdb.ast.CockroachDBJoin; +import sqlancer.cockroachdb.ast.CockroachDBJoin.JoinType; +import sqlancer.cockroachdb.ast.CockroachDBSelect; +import sqlancer.cockroachdb.ast.CockroachDBTableReference; +import sqlancer.cockroachdb.gen.CockroachDBExpressionGenerator; +import sqlancer.common.DBMSCommon; +import sqlancer.common.oracle.CERTOracleBase; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.query.SQLancerResultSet; + +public class CockroachDBCERTOracle extends CERTOracleBase + implements TestOracle { + private CockroachDBExpressionGenerator gen; + private CockroachDBSelect select; + + public CockroachDBCERTOracle(CockroachDBGlobalState globalState) { + super(globalState); + CockroachDBErrors.addExpressionErrors(errors); + } + + @Override + public void check() throws SQLException { + queryPlan1Sequences = new ArrayList<>(); + queryPlan2Sequences = new ArrayList<>(); + + // Randomly generate a query + CockroachDBTables tables = state.getSchema().getRandomTableNonEmptyTables(2); + List tableList = CockroachDBCommon.getTableReferences( + tables.getTables().stream().map(t -> new CockroachDBTableReference(t)).collect(Collectors.toList())); + gen = new CockroachDBExpressionGenerator(state).setColumns(tables.getColumns()); + List fetchColumns = new ArrayList<>(); + fetchColumns.addAll(Randomly.nonEmptySubset(tables.getColumns()).stream() + .map(c -> new CockroachDBColumnReference(c)).collect(Collectors.toList())); + select = new CockroachDBSelect(); + select.setFetchColumns(fetchColumns); + select.setFromList(tableList); + select.setDistinct(Randomly.getBoolean()); + if (Randomly.getBoolean()) { + select.setWhereClause(gen.generateExpression(CockroachDBDataType.BOOL.get())); + } + if (Randomly.getBoolean()) { + select.setGroupByExpressions(fetchColumns); + if (Randomly.getBoolean()) { + select.setHavingClause(gen.generateExpression(CockroachDBDataType.BOOL.get())); + } + } + + // Set the join. + List joinExpressions = getJoins(tableList, state); + select.setJoinList(joinExpressions); + + // Get the result of the first query + String queryString1 = CockroachDBVisitor.asString(select); + int rowCount1 = getRow(state, queryString1, queryPlan1Sequences); + + List excludes = new ArrayList<>(); + // Disable limit due to its false positive + excludes.add(Mutator.LIMIT); + if (CockroachDBBugs.bug131640) { + excludes.add(Mutator.OR); + } + if (CockroachDBBugs.bug131647) { + excludes.add(Mutator.JOIN); + } + // Mutate the query + boolean increase = mutate(excludes.toArray(new Mutator[0])); + + // Get the result of the second query + String queryString2 = CockroachDBVisitor.asString(select); + int rowCount2 = getRow(state, queryString2, queryPlan2Sequences); + + // Check structural equivalence + if (DBMSCommon.editDistance(queryPlan1Sequences, queryPlan2Sequences) > 1) { + return; + } + + // Check the results + if (increase && rowCount1 > rowCount2 || !increase && rowCount1 < rowCount2) { + throw new AssertionError("Inconsistent result for query: EXPLAIN " + queryString1 + "; --" + rowCount1 + + "\nEXPLAIN " + queryString2 + "; --" + rowCount2); + } + } + + private List getJoins(List tableList, + CockroachDBGlobalState globalState) throws AssertionError { + List joinExpressions = new ArrayList<>(); + while (tableList.size() >= 2 && Randomly.getPercentage() < 0.8) { + CockroachDBTableReference leftTable = (CockroachDBTableReference) tableList.remove(0); + CockroachDBTableReference rightTable = (CockroachDBTableReference) tableList.remove(0); + List columns = new ArrayList<>(leftTable.getTable().getColumns()); + columns.addAll(rightTable.getTable().getColumns()); + CockroachDBExpressionGenerator joinGen = new CockroachDBExpressionGenerator(globalState) + .setColumns(columns); + joinExpressions.add(CockroachDBJoin.createJoin(leftTable, rightTable, + CockroachDBJoin.JoinType.getRandomExcept(JoinType.NATURAL), + joinGen.generateExpression(CockroachDBDataType.BOOL.get()))); + } + return joinExpressions; + } + + @Override + protected boolean mutateJoin() { + if (select.getJoinList().isEmpty()) { + return false; + } + + CockroachDBJoin join = (CockroachDBJoin) Randomly.fromList(select.getJoinList()); + + // CROSS does not need ON Condition, while other joins do + // To avoid Null pointer, generating a new new condition when mutating CROSS to other joins + if (join.getJoinType() == JoinType.CROSS) { + List columns = new ArrayList<>(); + columns.addAll(((CockroachDBTableReference) join.getLeftTable()).getTable().getColumns()); + columns.addAll(((CockroachDBTableReference) join.getRightTable()).getTable().getColumns()); + CockroachDBExpressionGenerator joinGen2 = new CockroachDBExpressionGenerator(state).setColumns(columns); + join.setOnClause(joinGen2.generateExpression(CockroachDBDataType.BOOL.get())); + } + + JoinType newJoinType = CockroachDBJoin.JoinType.INNER; + if (join.getJoinType() == JoinType.LEFT || join.getJoinType() == JoinType.RIGHT) { // No invariant relation + // between LEFT and RIGHT + // join + newJoinType = CockroachDBJoin.JoinType.getRandomExcept(JoinType.NATURAL, JoinType.CROSS, JoinType.LEFT, + JoinType.RIGHT); + } else if (join.getJoinType() == JoinType.FULL) { + newJoinType = CockroachDBJoin.JoinType.getRandomExcept(JoinType.NATURAL, JoinType.CROSS); + } else if (join.getJoinType() != JoinType.CROSS) { + newJoinType = CockroachDBJoin.JoinType.getRandomExcept(JoinType.NATURAL, join.getJoinType()); + } + assert newJoinType != JoinType.NATURAL; // Natural Join is not supported for CERT + boolean increase = join.getJoinType().ordinal() < newJoinType.ordinal(); + join.setJoinType(newJoinType); + return increase; + } + + @Override + protected boolean mutateDistinct() { + boolean increase = select.isDistinct(); + select.setDistinct(!select.isDistinct()); + return increase; + } + + @Override + protected boolean mutateWhere() { + boolean increase = select.getWhereClause() != null; + if (increase) { + select.setWhereClause(null); + } else { + select.setWhereClause(gen.generateExpression(CockroachDBDataType.BOOL.get())); + } + return increase; + } + + @Override + protected boolean mutateGroupBy() { + boolean increase = !select.getGroupByExpressions().isEmpty(); + if (increase) { + select.clearGroupByExpressions(); + } else { + select.setGroupByExpressions(select.getFetchColumns()); + } + return increase; + } + + @Override + protected boolean mutateHaving() { + if (select.getGroupByExpressions().isEmpty()) { + select.setGroupByExpressions(select.getFetchColumns()); + select.setHavingClause(gen.generateExpression(CockroachDBDataType.BOOL.get())); + return false; + } else { + if (select.getHavingClause() == null) { + select.setHavingClause(gen.generateExpression(CockroachDBDataType.BOOL.get())); + return false; + } else { + select.setHavingClause(null); + return true; + } + } + } + + @Override + protected boolean mutateAnd() { + if (select.getWhereClause() == null) { + select.setWhereClause(gen.generateExpression(CockroachDBDataType.BOOL.get())); + } else { + CockroachDBExpression newWhere = new CockroachDBBinaryLogicalOperation(select.getWhereClause(), + gen.generateExpression(CockroachDBDataType.BOOL.get()), CockroachDBBinaryLogicalOperator.AND); + select.setWhereClause(newWhere); + } + return false; + } + + @Override + protected boolean mutateOr() { + if (select.getWhereClause() == null) { + select.setWhereClause(gen.generateExpression(CockroachDBDataType.BOOL.get())); + return false; + } else { + CockroachDBExpression newWhere = new CockroachDBBinaryLogicalOperation(select.getWhereClause(), + gen.generateExpression(CockroachDBDataType.BOOL.get()), CockroachDBBinaryLogicalOperator.OR); + select.setWhereClause(newWhere); + return true; + } + } + + @Override + protected boolean mutateLimit() { + boolean increase = select.getLimitClause() != null; + if (increase) { + select.setLimitClause(null); + } else { + select.setLimitClause(gen.generateConstant(CockroachDBDataType.INT.get())); + } + return increase; + } + + private int getRow(SQLGlobalState globalState, String selectStr, List queryPlanSequences) + throws AssertionError, SQLException { + int row = -1; + String explainQuery = "EXPLAIN " + selectStr; + + // Log the query + if (globalState.getOptions().logEachSelect()) { + globalState.getLogger().writeCurrent(explainQuery); + try { + globalState.getLogger().getCurrentFileWriter().flush(); + } catch (IOException e) { + e.printStackTrace(); + } + } + + // Get the row count + SQLQueryAdapter q = new SQLQueryAdapter(explainQuery, errors); + try (SQLancerResultSet rs = q.executeAndGet(globalState)) { + if (rs != null) { + while (rs.next()) { + String content = rs.getString(1); + if (content.contains("count:")) { + try { + int number = Integer.parseInt(content.split("count: ")[1].split(" ")[0].replace(",", "")); + if (row == -1) { + row = number; + } + } catch (Exception e) { // To avoid the situation that no number is found + } + } + if (content.contains("• ")) { + String operation = content.split("• ")[1].split(" ")[0]; + if (CockroachDBBugs.bug131875 && (operation.equals("distinct") || operation.equals("limit"))) { + throw new IgnoreMeException(); + } + queryPlanSequences.add(operation); + } + } + } + } catch (IgnoreMeException e) { + throw new IgnoreMeException(); + } catch (Exception e) { + throw new AssertionError(q.getQueryString(), e); + } + if (row == -1) { + throw new IgnoreMeException(); + } + return row; + } + +} diff --git a/src/sqlancer/cockroachdb/oracle/CockroachDBNoRECOracle.java b/src/sqlancer/cockroachdb/oracle/CockroachDBNoRECOracle.java deleted file mode 100644 index 3020f3238..000000000 --- a/src/sqlancer/cockroachdb/oracle/CockroachDBNoRECOracle.java +++ /dev/null @@ -1,152 +0,0 @@ -package sqlancer.cockroachdb.oracle; - -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.stream.Collectors; - -import sqlancer.IgnoreMeException; -import sqlancer.Randomly; -import sqlancer.SQLGlobalState; -import sqlancer.cockroachdb.CockroachDBCommon; -import sqlancer.cockroachdb.CockroachDBErrors; -import sqlancer.cockroachdb.CockroachDBProvider.CockroachDBGlobalState; -import sqlancer.cockroachdb.CockroachDBSchema.CockroachDBColumn; -import sqlancer.cockroachdb.CockroachDBSchema.CockroachDBDataType; -import sqlancer.cockroachdb.CockroachDBSchema.CockroachDBTables; -import sqlancer.cockroachdb.CockroachDBVisitor; -import sqlancer.cockroachdb.ast.CockroachDBColumnReference; -import sqlancer.cockroachdb.ast.CockroachDBExpression; -import sqlancer.cockroachdb.ast.CockroachDBJoin; -import sqlancer.cockroachdb.ast.CockroachDBJoin.OuterType; -import sqlancer.cockroachdb.ast.CockroachDBSelect; -import sqlancer.cockroachdb.ast.CockroachDBTableReference; -import sqlancer.cockroachdb.gen.CockroachDBExpressionGenerator; -import sqlancer.common.oracle.NoRECBase; -import sqlancer.common.oracle.TestOracle; -import sqlancer.common.query.ExpectedErrors; -import sqlancer.common.query.SQLQueryAdapter; -import sqlancer.common.query.SQLancerResultSet; - -public class CockroachDBNoRECOracle extends NoRECBase implements TestOracle { - - private CockroachDBExpressionGenerator gen; - - public CockroachDBNoRECOracle(CockroachDBGlobalState globalState) { - super(globalState); - CockroachDBErrors.addExpressionErrors(errors); - CockroachDBErrors.addTransactionErrors(errors); - errors.add("unable to vectorize execution plan"); // SET vectorize=experimental_always; - errors.add(" mismatched physical types at index"); // SET vectorize=experimental_always; - } - - @Override - public void check() throws SQLException { - CockroachDBTables tables = state.getSchema().getRandomTableNonEmptyTables(); - List tableL = tables.getTables().stream().map(t -> new CockroachDBTableReference(t)) - .collect(Collectors.toList()); - List tableList = CockroachDBCommon.getTableReferences(tableL); - gen = new CockroachDBExpressionGenerator(state).setColumns(tables.getColumns()); - List joinExpressions = getJoins(tableList, state); - CockroachDBExpression whereCondition = gen.generateExpression(CockroachDBDataType.BOOL.get()); - int optimizableCount = getOptimizedResult(whereCondition, tableList, errors, joinExpressions); - if (optimizableCount == -1) { - throw new IgnoreMeException(); - } - int nonOptimizableCount = getNonOptimizedResult(whereCondition, tableList, errors, joinExpressions); - if (nonOptimizableCount == -1) { - throw new IgnoreMeException(); - } - if (optimizableCount != nonOptimizableCount) { - state.getState().getLocalState().log(optimizedQueryString + ";\n" + unoptimizedQueryString + ";"); - throw new AssertionError(CockroachDBVisitor.asString(whereCondition)); - } - } - - public static List getJoins(List tableList, - CockroachDBGlobalState globalState) throws AssertionError { - List joinExpressions = new ArrayList<>(); - while (tableList.size() >= 2 && Randomly.getBoolean()) { - CockroachDBTableReference leftTable = (CockroachDBTableReference) tableList.remove(0); - CockroachDBTableReference rightTable = (CockroachDBTableReference) tableList.remove(0); - List columns = new ArrayList<>(leftTable.getTable().getColumns()); - columns.addAll(rightTable.getTable().getColumns()); - CockroachDBExpressionGenerator joinGen = new CockroachDBExpressionGenerator(globalState) - .setColumns(columns); - switch (CockroachDBJoin.JoinType.getRandom()) { - case INNER: - joinExpressions.add(new CockroachDBJoin(leftTable, rightTable, CockroachDBJoin.JoinType.INNER, - joinGen.generateExpression(CockroachDBDataType.BOOL.get()))); - break; - case NATURAL: - joinExpressions.add(CockroachDBJoin.createNaturalJoin(leftTable, rightTable)); - break; - case CROSS: - joinExpressions.add(CockroachDBJoin.createCrossJoin(leftTable, rightTable)); - break; - case OUTER: - joinExpressions.add(CockroachDBJoin.createOuterJoin(leftTable, rightTable, OuterType.getRandom(), - joinGen.generateExpression(CockroachDBDataType.BOOL.get()))); - break; - default: - throw new AssertionError(); - } - } - return joinExpressions; - } - - private int getOptimizedResult(CockroachDBExpression whereCondition, List tableList, - ExpectedErrors errors, List joinExpressions) throws SQLException { - CockroachDBSelect select = new CockroachDBSelect(); - CockroachDBColumn c = new CockroachDBColumn("COUNT(*)", null, false, false); - select.setFetchColumns(Arrays.asList(new CockroachDBColumnReference(c))); - select.setFromList(tableList); - select.setWhereClause(whereCondition); - select.setJoinList(joinExpressions); - if (Randomly.getBooleanWithRatherLowProbability()) { - select.setOrderByExpressions(gen.getOrderingTerms()); - } - String s = CockroachDBVisitor.asString(select); - if (state.getOptions().logEachSelect()) { - state.getLogger().writeCurrent(s); - } - this.optimizedQueryString = s; - SQLQueryAdapter q = new SQLQueryAdapter(s, errors); - return getCount(state, q); - } - - private int getNonOptimizedResult(CockroachDBExpression whereCondition, List tableList, - ExpectedErrors errors, List joinList) throws SQLException { - String fromString = tableList.stream().map(t -> ((CockroachDBTableReference) t).getTable().getName()) - .collect(Collectors.joining(", ")); - if (!tableList.isEmpty() && !joinList.isEmpty()) { - fromString += ", "; - } - String s = "SELECT SUM(count) FROM (SELECT CAST(" + CockroachDBVisitor.asString(whereCondition) - + " IS TRUE AS INT) as count FROM " + fromString + " " - + joinList.stream().map(j -> CockroachDBVisitor.asString(j)).collect(Collectors.joining(", ")) + ")"; - if (state.getOptions().logEachSelect()) { - state.getLogger().writeCurrent(s); - } - this.unoptimizedQueryString = s; - SQLQueryAdapter q = new SQLQueryAdapter(s, errors); - return getCount(state, q); - } - - private int getCount(SQLGlobalState globalState, SQLQueryAdapter q) throws AssertionError { - int count = 0; - try (SQLancerResultSet rs = q.executeAndGet(globalState)) { - if (rs == null) { - return -1; - } - if (rs.next()) { - count = rs.getInt(1); - } - } catch (Exception e) { - throw new AssertionError(q.getQueryString(), e); - } - return count; - } - -} diff --git a/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPAggregateOracle.java b/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPAggregateOracle.java index bc2601ef5..6dd4bd5d5 100644 --- a/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPAggregateOracle.java +++ b/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPAggregateOracle.java @@ -29,13 +29,12 @@ import sqlancer.cockroachdb.ast.CockroachDBUnaryPostfixOperation; import sqlancer.cockroachdb.ast.CockroachDBUnaryPostfixOperation.CockroachDBUnaryPostfixOperator; import sqlancer.cockroachdb.gen.CockroachDBExpressionGenerator; -import sqlancer.cockroachdb.oracle.CockroachDBNoRECOracle; import sqlancer.common.oracle.TestOracle; import sqlancer.common.query.ExpectedErrors; import sqlancer.common.query.SQLQueryAdapter; import sqlancer.common.query.SQLancerResultSet; -public class CockroachDBTLPAggregateOracle implements TestOracle { +public class CockroachDBTLPAggregateOracle implements TestOracle { private final CockroachDBGlobalState state; private final ExpectedErrors errors = new ExpectedErrors(); @@ -72,11 +71,11 @@ public void check() throws SQLException { .map(t -> new CockroachDBTableReference(t)).collect(Collectors.toList()); List from = CockroachDBCommon.getTableReferences(tableList); if (Randomly.getBooleanWithRatherLowProbability()) { - select.setJoinList(CockroachDBNoRECOracle.getJoins(from, state)); + select.setJoinList(CockroachDBTLPBase.getJoins(from, state)); } select.setFromList(from); if (Randomly.getBooleanWithRatherLowProbability()) { - select.setOrderByExpressions(gen.getOrderingTerms()); + select.setOrderByClauses(gen.getOrderingTerms()); } originalQuery = CockroachDBVisitor.asString(select); firstResult = getAggregateResult(originalQuery); @@ -85,8 +84,8 @@ public void check() throws SQLException { state.getState().getLocalState().log( "--" + originalQuery + ";\n--" + metamorphicQuery + "\n-- " + firstResult + "\n-- " + secondResult); - if (firstResult == null && secondResult != null - || firstResult != null && (!firstResult.contentEquals(secondResult) + if (firstResult == null && secondResult != null || firstResult != null && secondResult == null + || firstResult != null && secondResult != null && (!firstResult.contentEquals(secondResult) && !ComparatorHelper.isEqualDouble(firstResult, secondResult))) { if (secondResult.contains("Inf")) { throw new IgnoreMeException(); // FIXME: average computation @@ -147,13 +146,15 @@ private List mapped(CockroachDBAggregate aggregate) { case MIN: return aliasArgs(Arrays.asList(aggregate)); case AVG: - // List arg = Arrays.asList(new CockroachDBCast(aggregate.getExpr().get(0), + // List arg = Arrays.asList(new + // CockroachDBCast(aggregate.getExpr().get(0), // CockroachDBDataType.DECIMAL.get())); CockroachDBAggregate sum = new CockroachDBAggregate(CockroachDBAggregateFunction.SUM, aggregate.getExpr()); CockroachDBCast count = new CockroachDBCast( new CockroachDBAggregate(CockroachDBAggregateFunction.COUNT, aggregate.getExpr()), CockroachDBDataType.DECIMAL.get()); - // CockroachDBBinaryArithmeticOperation avg = new CockroachDBBinaryArithmeticOperation(sum, count, + // CockroachDBBinaryArithmeticOperation avg = new + // CockroachDBBinaryArithmeticOperation(sum, count, // CockroachDBBinaryArithmeticOperator.DIV); return aliasArgs(Arrays.asList(sum, count)); default: @@ -195,4 +196,9 @@ private CockroachDBSelect getSelect(List aggregates, List return leftSelect; } + @Override + public String getLastQueryString() { + return originalQuery; + } + } diff --git a/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPBase.java b/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPBase.java index 7e7f03e26..d84cab56f 100644 --- a/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPBase.java +++ b/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPBase.java @@ -10,20 +10,22 @@ import sqlancer.cockroachdb.CockroachDBProvider.CockroachDBGlobalState; import sqlancer.cockroachdb.CockroachDBSchema; import sqlancer.cockroachdb.CockroachDBSchema.CockroachDBColumn; +import sqlancer.cockroachdb.CockroachDBSchema.CockroachDBDataType; import sqlancer.cockroachdb.CockroachDBSchema.CockroachDBTable; import sqlancer.cockroachdb.CockroachDBSchema.CockroachDBTables; import sqlancer.cockroachdb.ast.CockroachDBColumnReference; import sqlancer.cockroachdb.ast.CockroachDBExpression; +import sqlancer.cockroachdb.ast.CockroachDBJoin; import sqlancer.cockroachdb.ast.CockroachDBSelect; import sqlancer.cockroachdb.ast.CockroachDBTableReference; import sqlancer.cockroachdb.gen.CockroachDBExpressionGenerator; -import sqlancer.cockroachdb.oracle.CockroachDBNoRECOracle; import sqlancer.common.gen.ExpressionGenerator; import sqlancer.common.oracle.TernaryLogicPartitioningOracleBase; import sqlancer.common.oracle.TestOracle; -public class CockroachDBTLPBase extends - TernaryLogicPartitioningOracleBase implements TestOracle { +public class CockroachDBTLPBase + extends TernaryLogicPartitioningOracleBase + implements TestOracle { CockroachDBSchema s; CockroachDBTables targetTables; @@ -46,7 +48,7 @@ public void check() throws SQLException { List tables = targetTables.getTables(); List tableList = tables.stream().map(t -> new CockroachDBTableReference(t)) .collect(Collectors.toList()); - List joins = CockroachDBNoRECOracle.getJoins(tableList, state); + List joins = getJoins(tableList, state); select.setJoinList(joins); select.setFromList(tableList); select.setWhereClause(null); @@ -54,7 +56,7 @@ public void check() throws SQLException { List generateFetchColumns() { List columns = new ArrayList<>(); - if (Randomly.getBoolean() || targetTables.getColumns().size() == 0) { + if (Randomly.getBoolean() || targetTables.getColumns().isEmpty()) { columns.add(new CockroachDBColumnReference(new CockroachDBColumn("*", null, false, false))); } else { columns.addAll(Randomly.nonEmptySubset(targetTables.getColumns()).stream() @@ -68,4 +70,20 @@ protected ExpressionGenerator getGen() { return gen; } + public static List getJoins(List tableList, + CockroachDBGlobalState globalState) throws AssertionError { + List joinExpressions = new ArrayList<>(); + while (tableList.size() >= 2 && Randomly.getBoolean()) { + CockroachDBTableReference leftTable = (CockroachDBTableReference) tableList.remove(0); + CockroachDBTableReference rightTable = (CockroachDBTableReference) tableList.remove(0); + List columns = new ArrayList<>(leftTable.getTable().getColumns()); + columns.addAll(rightTable.getTable().getColumns()); + CockroachDBExpressionGenerator joinGen = new CockroachDBExpressionGenerator(globalState) + .setColumns(columns); + joinExpressions.add(CockroachDBJoin.createJoin(leftTable, rightTable, CockroachDBJoin.JoinType.getRandom(), + joinGen.generateExpression(CockroachDBDataType.BOOL.get()))); + } + return joinExpressions; + } + } diff --git a/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPDistinctOracle.java b/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPDistinctOracle.java index bdf8f59c3..dd174e805 100644 --- a/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPDistinctOracle.java +++ b/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPDistinctOracle.java @@ -15,6 +15,8 @@ public class CockroachDBTLPDistinctOracle extends CockroachDBTLPBase { + private String generatedQueryString; + public CockroachDBTLPDistinctOracle(CockroachDBGlobalState state) { super(state); errors.add("GROUP BY term out of range"); @@ -25,7 +27,7 @@ public void check() throws SQLException { super.check(); select.setDistinct(true); String originalQueryString = CockroachDBVisitor.asString(select); - + generatedQueryString = originalQueryString; List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); select.setDistinct(false); CockroachDBExpression predicate = gen.generateExpression(CockroachDBDataType.BOOL.get()); @@ -41,4 +43,9 @@ public void check() throws SQLException { ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, state); } + + @Override + public String getLastQueryString() { + return generatedQueryString; + } } diff --git a/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPExtendedWhereOracle.java b/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPExtendedWhereOracle.java index 36ac46733..863f056ab 100644 --- a/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPExtendedWhereOracle.java +++ b/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPExtendedWhereOracle.java @@ -19,6 +19,7 @@ public class CockroachDBTLPExtendedWhereOracle extends CockroachDBTLPBase { private CockroachDBExpression originalPredicate; + private String generatedQueryString; public CockroachDBTLPExtendedWhereOracle(CockroachDBGlobalState state) { super(state); @@ -32,11 +33,12 @@ public void check() throws SQLException { originalPredicate = generatePredicate(); select.setWhereClause(originalPredicate); String originalQueryString = CockroachDBVisitor.asString(select); + generatedQueryString = originalQueryString; List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); boolean allowOrderBy = Randomly.getBoolean(); if (allowOrderBy) { - select.setOrderByExpressions(gen.getOrderingTerms()); + select.setOrderByClauses(gen.getOrderingTerms()); } select.setWhereClause(combinePredicate(predicate)); String firstQueryString = CockroachDBVisitor.asString(select); @@ -56,4 +58,9 @@ public CockroachDBExpression combinePredicate(CockroachDBExpression expr) { return new CockroachDBBinaryLogicalOperation(originalPredicate, expr, CockroachDBBinaryLogicalOperator.AND); } + + @Override + public String getLastQueryString() { + return generatedQueryString; + } } diff --git a/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPGroupByOracle.java b/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPGroupByOracle.java index 535fea3e1..466252e56 100644 --- a/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPGroupByOracle.java +++ b/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPGroupByOracle.java @@ -14,6 +14,8 @@ public class CockroachDBTLPGroupByOracle extends CockroachDBTLPBase { + private String generatedQueryString; + public CockroachDBTLPGroupByOracle(CockroachDBGlobalState state) { super(state); } @@ -24,7 +26,7 @@ public void check() throws SQLException { select.setGroupByExpressions(select.getFetchColumns()); select.setWhereClause(null); String originalQueryString = CockroachDBVisitor.asString(select); - + generatedQueryString = originalQueryString; List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); select.setWhereClause(predicate); @@ -46,4 +48,9 @@ List generateFetchColumns() { .collect(Collectors.toList())); } + @Override + public String getLastQueryString() { + return generatedQueryString; + } + } diff --git a/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPHavingOracle.java b/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPHavingOracle.java index 42abad00b..7a7dd7779 100644 --- a/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPHavingOracle.java +++ b/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPHavingOracle.java @@ -13,6 +13,8 @@ public class CockroachDBTLPHavingOracle extends CockroachDBTLPBase { + private String generatedQueryString; + public CockroachDBTLPHavingOracle(CockroachDBGlobalState state) { super(state); errors.add("GROUP BY term out of range"); @@ -26,11 +28,12 @@ public void check() throws SQLException { } boolean orderBy = Randomly.getBoolean(); if (orderBy) { - select.setOrderByExpressions(gen.generateOrderBys()); + select.setOrderByClauses(gen.generateOrderBys()); } select.setGroupByExpressions(gen.generateExpressions(Randomly.smallNumber() + 1)); select.setHavingClause(null); String originalQueryString = CockroachDBVisitor.asString(select); + generatedQueryString = originalQueryString; List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); CockroachDBExpression predicate = gen.generateExpression(CockroachDBDataType.BOOL.get()); @@ -52,4 +55,9 @@ protected CockroachDBExpression generatePredicate() { return gen.generateHavingClause(); } + @Override + public String getLastQueryString() { + return generatedQueryString; + } + } diff --git a/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPWhereOracle.java b/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPWhereOracle.java deleted file mode 100644 index 8fe0899d0..000000000 --- a/src/sqlancer/cockroachdb/oracle/tlp/CockroachDBTLPWhereOracle.java +++ /dev/null @@ -1,48 +0,0 @@ -package sqlancer.cockroachdb.oracle.tlp; - -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.List; - -import sqlancer.ComparatorHelper; -import sqlancer.Randomly; -import sqlancer.cockroachdb.CockroachDBProvider.CockroachDBGlobalState; -import sqlancer.cockroachdb.CockroachDBSchema.CockroachDBDataType; -import sqlancer.cockroachdb.CockroachDBVisitor; -import sqlancer.cockroachdb.ast.CockroachDBExpression; -import sqlancer.cockroachdb.ast.CockroachDBNotOperation; -import sqlancer.cockroachdb.ast.CockroachDBUnaryPostfixOperation; -import sqlancer.cockroachdb.ast.CockroachDBUnaryPostfixOperation.CockroachDBUnaryPostfixOperator; - -public class CockroachDBTLPWhereOracle extends CockroachDBTLPBase { - - public CockroachDBTLPWhereOracle(CockroachDBGlobalState state) { - super(state); - errors.add("GROUP BY term out of range"); - } - - @Override - public void check() throws SQLException { - super.check(); - String originalQueryString = CockroachDBVisitor.asString(select); - - List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); - - boolean allowOrderBy = Randomly.getBoolean(); - if (allowOrderBy) { - select.setOrderByExpressions(gen.getOrderingTerms()); - } - CockroachDBExpression predicate = gen.generateExpression(CockroachDBDataType.BOOL.get()); - select.setWhereClause(predicate); - String firstQueryString = CockroachDBVisitor.asString(select); - select.setWhereClause(new CockroachDBNotOperation(predicate)); - String secondQueryString = CockroachDBVisitor.asString(select); - select.setWhereClause(new CockroachDBUnaryPostfixOperation(predicate, CockroachDBUnaryPostfixOperator.IS_NULL)); - String thirdQueryString = CockroachDBVisitor.asString(select); - List combinedString = new ArrayList<>(); - List secondResultSet = ComparatorHelper.getCombinedResultSet(firstQueryString, secondQueryString, - thirdQueryString, combinedString, !allowOrderBy, state, errors); - ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, - state); - } -} diff --git a/src/sqlancer/common/DBMSCommon.java b/src/sqlancer/common/DBMSCommon.java index 2c7531d10..4478be2b3 100644 --- a/src/sqlancer/common/DBMSCommon.java +++ b/src/sqlancer/common/DBMSCommon.java @@ -1,5 +1,6 @@ package sqlancer.common; +import java.util.List; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -27,4 +28,42 @@ public static boolean matchesIndexName(String indexName) { return matcher.matches(); } + public static int getMaxIndexInDoubleArray(double... doubleArray) { + int maxIndex = 0; + double maxValue = 0.0; + for (int j = 0; j < doubleArray.length; j++) { + double curReward = doubleArray[j]; + if (curReward > maxValue) { + maxIndex = j; + maxValue = curReward; + } + } + return maxIndex; + } + + public static boolean areQueryPlanSequencesSimilar(List list1, List list2) { + return editDistance(list1, list2) <= 1; + } + + public static int editDistance(List list1, List list2) { + int[][] dp = new int[list1.size() + 1][list2.size() + 1]; + for (int i = 0; i <= list1.size(); i++) { + for (int j = 0; j <= list2.size(); j++) { + if (i == 0) { + dp[i][j] = j; + } else if (j == 0) { + dp[i][j] = i; + } else { + dp[i][j] = Math.min(dp[i - 1][j - 1] + costOfSubstitution(list1.get(i - 1), list2.get(j - 1)), + Math.min(dp[i - 1][j] + 1, dp[i][j - 1] + 1)); + } + } + } + return dp[list1.size()][list2.size()]; + } + + private static int costOfSubstitution(String string, String string2) { + return string.equals(string2) ? 0 : 1; + } + } diff --git a/src/sqlancer/common/ast/BinaryNode.java b/src/sqlancer/common/ast/BinaryNode.java index dec1854c4..90a56a078 100644 --- a/src/sqlancer/common/ast/BinaryNode.java +++ b/src/sqlancer/common/ast/BinaryNode.java @@ -7,7 +7,7 @@ public abstract class BinaryNode implements BinaryOperation { private final T left; private final T right; - public BinaryNode(T left, T right) { + protected BinaryNode(T left, T right) { this.left = left; this.right = right; } diff --git a/src/sqlancer/common/ast/BinaryOperatorNode.java b/src/sqlancer/common/ast/BinaryOperatorNode.java index 38d340a4b..586315059 100644 --- a/src/sqlancer/common/ast/BinaryOperatorNode.java +++ b/src/sqlancer/common/ast/BinaryOperatorNode.java @@ -10,7 +10,7 @@ public interface Operator { String getTextRepresentation(); } - public BinaryOperatorNode(T left, T right, O op) { + protected BinaryOperatorNode(T left, T right, O op) { super(left, right); this.op = op; } diff --git a/src/sqlancer/common/ast/FunctionNode.java b/src/sqlancer/common/ast/FunctionNode.java index a841d2956..ad795b998 100644 --- a/src/sqlancer/common/ast/FunctionNode.java +++ b/src/sqlancer/common/ast/FunctionNode.java @@ -7,7 +7,7 @@ public abstract class FunctionNode { protected F function; protected List args; - public FunctionNode(F function, List args) { + protected FunctionNode(F function, List args) { this.function = function; this.args = args; } diff --git a/src/sqlancer/common/ast/SelectBase.java b/src/sqlancer/common/ast/SelectBase.java index 341b00209..e79a1a87d 100644 --- a/src/sqlancer/common/ast/SelectBase.java +++ b/src/sqlancer/common/ast/SelectBase.java @@ -36,6 +36,10 @@ public void setFromList(List fromList) { this.fromList = fromList; } + public void setFromTables(List tables) { + setFromList(tables); + } + public List getFromList() { if (fromList == null) { throw new IllegalStateException(); @@ -50,19 +54,23 @@ public void setGroupByExpressions(List groupByExpressions) { this.groupByExpressions = groupByExpressions; } + public void clearGroupByExpressions() { + this.groupByExpressions = Collections.emptyList(); + } + public List getGroupByExpressions() { assert groupByExpressions != null; return groupByExpressions; } - public void setOrderByExpressions(List orderByExpressions) { + public void setOrderByClauses(List orderByExpressions) { if (orderByExpressions == null) { throw new IllegalArgumentException(); } this.orderByExpressions = orderByExpressions; } - public List getOrderByExpressions() { + public List getOrderByClauses() { assert orderByExpressions != null; return orderByExpressions; } @@ -83,6 +91,10 @@ public T getHavingClause() { return havingClause; } + public void clearHavingClause() { + this.havingClause = null; + } + public void setLimitClause(T limitClause) { this.limitClause = limitClause; } @@ -107,4 +119,11 @@ public void setJoinList(List joinList) { this.joinList = joinList; } + public List getGroupByClause() { + return getGroupByExpressions(); + } + + public void setGroupByClause(List groupByExpressions) { + setGroupByExpressions(groupByExpressions); + } } diff --git a/src/sqlancer/common/ast/TernaryNode.java b/src/sqlancer/common/ast/TernaryNode.java index 146a46474..1a6605317 100644 --- a/src/sqlancer/common/ast/TernaryNode.java +++ b/src/sqlancer/common/ast/TernaryNode.java @@ -8,7 +8,7 @@ public abstract class TernaryNode implements BinaryOperation { private final T middle; private final T right; - public TernaryNode(T left, T middle, T right) { + protected TernaryNode(T left, T middle, T right) { this.left = left; this.middle = middle; this.right = right; diff --git a/src/sqlancer/common/ast/UnaryNode.java b/src/sqlancer/common/ast/UnaryNode.java index a6664c864..dea4220b9 100644 --- a/src/sqlancer/common/ast/UnaryNode.java +++ b/src/sqlancer/common/ast/UnaryNode.java @@ -6,7 +6,7 @@ public abstract class UnaryNode implements UnaryOperation { protected final T expr; - public UnaryNode(T expr) { + protected UnaryNode(T expr) { this.expr = expr; } diff --git a/src/sqlancer/common/ast/UnaryOperatorNode.java b/src/sqlancer/common/ast/UnaryOperatorNode.java index 979479c12..71556496c 100644 --- a/src/sqlancer/common/ast/UnaryOperatorNode.java +++ b/src/sqlancer/common/ast/UnaryOperatorNode.java @@ -6,7 +6,7 @@ public abstract class UnaryOperatorNode extends UnaryNode protected final O op; - public UnaryOperatorNode(T expr, O op) { + protected UnaryOperatorNode(T expr, O op) { super(expr); this.op = op; } diff --git a/src/sqlancer/common/ast/newast/ColumnReferenceNode.java b/src/sqlancer/common/ast/newast/ColumnReferenceNode.java index 18ff09db6..2d8688d8f 100644 --- a/src/sqlancer/common/ast/newast/ColumnReferenceNode.java +++ b/src/sqlancer/common/ast/newast/ColumnReferenceNode.java @@ -2,7 +2,7 @@ import sqlancer.common.schema.AbstractTableColumn; -public class ColumnReferenceNode> implements Node { +public class ColumnReferenceNode> { private final C c; diff --git a/src/sqlancer/common/ast/newast/Constant.java b/src/sqlancer/common/ast/newast/Constant.java new file mode 100644 index 000000000..a6a6dfd49 --- /dev/null +++ b/src/sqlancer/common/ast/newast/Constant.java @@ -0,0 +1,6 @@ +package sqlancer.common.ast.newast; + +import sqlancer.common.schema.AbstractTableColumn; + +public interface Constant> extends Expression { +} diff --git a/src/sqlancer/common/ast/newast/Expression.java b/src/sqlancer/common/ast/newast/Expression.java new file mode 100644 index 000000000..925cb82c5 --- /dev/null +++ b/src/sqlancer/common/ast/newast/Expression.java @@ -0,0 +1,6 @@ +package sqlancer.common.ast.newast; + +import sqlancer.common.schema.AbstractTableColumn; + +public interface Expression> { +} diff --git a/src/sqlancer/common/ast/newast/Join.java b/src/sqlancer/common/ast/newast/Join.java new file mode 100644 index 000000000..94d26c139 --- /dev/null +++ b/src/sqlancer/common/ast/newast/Join.java @@ -0,0 +1,10 @@ +package sqlancer.common.ast.newast; + +import sqlancer.common.schema.AbstractTable; +import sqlancer.common.schema.AbstractTableColumn; + +public interface Join, T extends AbstractTable, C extends AbstractTableColumn> + extends Expression { + + void setOnClause(E onClause); +} diff --git a/src/sqlancer/common/ast/newast/NewAliasNode.java b/src/sqlancer/common/ast/newast/NewAliasNode.java index 9d7d84c43..260900712 100644 --- a/src/sqlancer/common/ast/newast/NewAliasNode.java +++ b/src/sqlancer/common/ast/newast/NewAliasNode.java @@ -1,16 +1,16 @@ package sqlancer.common.ast.newast; -public class NewAliasNode implements Node { +public class NewAliasNode { - private final Node expr; + private final E expr; private final String alias; - public NewAliasNode(Node expr, String alias) { + public NewAliasNode(E expr, String alias) { this.expr = expr; this.alias = alias; } - public Node getExpr() { + public E getExpr() { return expr; } diff --git a/src/sqlancer/common/ast/newast/NewBetweenOperatorNode.java b/src/sqlancer/common/ast/newast/NewBetweenOperatorNode.java index c5a0acf27..4bb9ce17a 100644 --- a/src/sqlancer/common/ast/newast/NewBetweenOperatorNode.java +++ b/src/sqlancer/common/ast/newast/NewBetweenOperatorNode.java @@ -1,28 +1,28 @@ package sqlancer.common.ast.newast; -public class NewBetweenOperatorNode implements Node { +public class NewBetweenOperatorNode { - protected Node left; - protected Node middle; - protected Node right; + protected T left; + protected T middle; + protected T right; protected boolean isTrue; - public NewBetweenOperatorNode(Node left, Node middle, Node right, boolean isTrue) { + public NewBetweenOperatorNode(T left, T middle, T right, boolean isTrue) { this.left = left; this.middle = middle; this.right = right; this.isTrue = isTrue; } - public Node getLeft() { + public T getLeft() { return left; } - public Node getMiddle() { + public T getMiddle() { return middle; } - public Node getRight() { + public T getRight() { return right; } diff --git a/src/sqlancer/common/ast/newast/NewBinaryOperatorNode.java b/src/sqlancer/common/ast/newast/NewBinaryOperatorNode.java index 2b640ed0c..b2fcc5ef0 100644 --- a/src/sqlancer/common/ast/newast/NewBinaryOperatorNode.java +++ b/src/sqlancer/common/ast/newast/NewBinaryOperatorNode.java @@ -2,13 +2,13 @@ import sqlancer.common.ast.BinaryOperatorNode.Operator; -public class NewBinaryOperatorNode implements Node { +public class NewBinaryOperatorNode { protected final Operator op; - protected final Node left; - protected final Node right; + protected final T left; + protected final T right; - public NewBinaryOperatorNode(Node left, Node right, Operator op) { + public NewBinaryOperatorNode(T left, T right, Operator op) { this.left = left; this.right = right; this.op = op; @@ -18,11 +18,11 @@ public String getOperatorRepresentation() { return op.getTextRepresentation(); } - public Node getLeft() { + public T getLeft() { return left; } - public Node getRight() { + public T getRight() { return right; } diff --git a/src/sqlancer/common/ast/newast/NewCaseOperatorNode.java b/src/sqlancer/common/ast/newast/NewCaseOperatorNode.java index 04f6be4c5..440078566 100644 --- a/src/sqlancer/common/ast/newast/NewCaseOperatorNode.java +++ b/src/sqlancer/common/ast/newast/NewCaseOperatorNode.java @@ -2,15 +2,14 @@ import java.util.List; -public class NewCaseOperatorNode implements Node { +public class NewCaseOperatorNode { - private final List> conditions; - private final List> expressions; - private final Node elseExpr; - private final Node switchCondition; + private final List conditions; + private final List expressions; + private final T elseExpr; + private final T switchCondition; - public NewCaseOperatorNode(Node switchCondition, List> conditions, List> expressions, - Node elseExpr) { + public NewCaseOperatorNode(T switchCondition, List conditions, List expressions, T elseExpr) { this.switchCondition = switchCondition; this.conditions = conditions; this.expressions = expressions; @@ -20,19 +19,19 @@ public NewCaseOperatorNode(Node switchCondition, List> conditions, Li } } - public Node getSwitchCondition() { + public T getSwitchCondition() { return switchCondition; } - public List> getConditions() { + public List getConditions() { return conditions; } - public List> getExpressions() { + public List getExpressions() { return expressions; } - public Node getElseExpr() { + public T getElseExpr() { return elseExpr; } diff --git a/src/sqlancer/common/ast/newast/NewFunctionNode.java b/src/sqlancer/common/ast/newast/NewFunctionNode.java index fae214713..cdd91bd61 100644 --- a/src/sqlancer/common/ast/newast/NewFunctionNode.java +++ b/src/sqlancer/common/ast/newast/NewFunctionNode.java @@ -2,17 +2,17 @@ import java.util.List; -public class NewFunctionNode implements Node { +public class NewFunctionNode { - protected List> args; + protected List args; protected F func; - public NewFunctionNode(List> args, F func) { + public NewFunctionNode(List args, F func) { this.args = args; this.func = func; } - public List> getArgs() { + public List getArgs() { return args; } diff --git a/src/sqlancer/common/ast/newast/NewInOperatorNode.java b/src/sqlancer/common/ast/newast/NewInOperatorNode.java index 346156666..94a3a0886 100644 --- a/src/sqlancer/common/ast/newast/NewInOperatorNode.java +++ b/src/sqlancer/common/ast/newast/NewInOperatorNode.java @@ -2,23 +2,23 @@ import java.util.List; -public class NewInOperatorNode implements Node { +public class NewInOperatorNode { - private final Node left; - private final List> right; + private final T left; + private final List right; private final boolean isNegated; - public NewInOperatorNode(Node left, List> right, boolean isNegated) { + public NewInOperatorNode(T left, List right, boolean isNegated) { this.left = left; this.right = right; this.isNegated = isNegated; } - public Node getLeft() { + public T getLeft() { return left; } - public List> getRight() { + public List getRight() { return right; } diff --git a/src/sqlancer/common/ast/newast/NewOrderingTerm.java b/src/sqlancer/common/ast/newast/NewOrderingTerm.java index 5afe71821..2efed05bc 100644 --- a/src/sqlancer/common/ast/newast/NewOrderingTerm.java +++ b/src/sqlancer/common/ast/newast/NewOrderingTerm.java @@ -2,9 +2,9 @@ import sqlancer.Randomly; -public class NewOrderingTerm implements Node { +public class NewOrderingTerm { - private final Node expr; + private final T expr; private final Ordering ordering; public enum Ordering { @@ -15,12 +15,12 @@ public static Ordering getRandom() { } } - public NewOrderingTerm(Node expr, Ordering ordering) { + public NewOrderingTerm(T expr, Ordering ordering) { this.expr = expr; this.ordering = ordering; } - public Node getExpr() { + public T getExpr() { return expr; } diff --git a/src/sqlancer/common/ast/newast/NewPostfixTextNode.java b/src/sqlancer/common/ast/newast/NewPostfixTextNode.java index 70e65b17b..716dbf8e3 100644 --- a/src/sqlancer/common/ast/newast/NewPostfixTextNode.java +++ b/src/sqlancer/common/ast/newast/NewPostfixTextNode.java @@ -1,16 +1,16 @@ package sqlancer.common.ast.newast; -public class NewPostfixTextNode implements Node { +public class NewPostfixTextNode { - private final Node expr; + private final T expr; private final String text; - public NewPostfixTextNode(Node expr, String text) { + public NewPostfixTextNode(T expr, String text) { this.expr = expr; this.text = text; } - public Node getExpr() { + public T getExpr() { return expr; } diff --git a/src/sqlancer/common/ast/newast/NewTernaryNode.java b/src/sqlancer/common/ast/newast/NewTernaryNode.java index f40c480ec..2ca7b1f30 100644 --- a/src/sqlancer/common/ast/newast/NewTernaryNode.java +++ b/src/sqlancer/common/ast/newast/NewTernaryNode.java @@ -1,14 +1,14 @@ package sqlancer.common.ast.newast; -public class NewTernaryNode implements Node { +public class NewTernaryNode { - protected final Node left; - protected final Node middle; - protected final Node right; + protected final T left; + protected final T middle; + protected final T right; private final String leftStr; private final String rightStr; - public NewTernaryNode(Node left, Node middle, Node right, String leftStr, String rightStr) { + public NewTernaryNode(T left, T middle, T right, String leftStr, String rightStr) { this.left = left; this.middle = middle; this.right = right; @@ -16,15 +16,15 @@ public NewTernaryNode(Node left, Node middle, Node right, String leftSt this.rightStr = rightStr; } - public Node getLeft() { + public T getLeft() { return left; } - public Node getMiddle() { + public T getMiddle() { return middle; } - public Node getRight() { + public T getRight() { return right; } diff --git a/src/sqlancer/common/ast/newast/NewToStringVisitor.java b/src/sqlancer/common/ast/newast/NewToStringVisitor.java index 3ad625440..82b6bace2 100644 --- a/src/sqlancer/common/ast/newast/NewToStringVisitor.java +++ b/src/sqlancer/common/ast/newast/NewToStringVisitor.java @@ -7,7 +7,7 @@ public abstract class NewToStringVisitor { protected final StringBuilder sb = new StringBuilder(); @SuppressWarnings("unchecked") - public void visit(Node expr) { + public void visit(E expr) { assert expr != null; if (expr instanceof ColumnReferenceNode) { sb.append(((ColumnReferenceNode) expr).getColumn().getFullQualifiedName()); @@ -40,7 +40,7 @@ public void visit(Node expr) { } } - public void visit(List> expressions) { + public void visit(List expressions) { for (int i = 0; i < expressions.size(); i++) { if (i != 0) { sb.append(", "); @@ -165,6 +165,6 @@ public String get() { return sb.toString(); } - public abstract void visitSpecific(Node expr); + public abstract void visitSpecific(E expr); } diff --git a/src/sqlancer/common/ast/newast/NewUnaryPostfixOperatorNode.java b/src/sqlancer/common/ast/newast/NewUnaryPostfixOperatorNode.java index ebd17945f..b3ccd4cd3 100644 --- a/src/sqlancer/common/ast/newast/NewUnaryPostfixOperatorNode.java +++ b/src/sqlancer/common/ast/newast/NewUnaryPostfixOperatorNode.java @@ -2,12 +2,12 @@ import sqlancer.common.ast.BinaryOperatorNode.Operator; -public class NewUnaryPostfixOperatorNode implements Node { +public class NewUnaryPostfixOperatorNode { protected final Operator op; - private final Node expr; + private final T expr; - public NewUnaryPostfixOperatorNode(Node expr, Operator op) { + public NewUnaryPostfixOperatorNode(T expr, Operator op) { this.expr = expr; this.op = op; } @@ -16,8 +16,7 @@ public String getOperatorRepresentation() { return op.getTextRepresentation(); } - public Node getExpr() { + public T getExpr() { return expr; } - } diff --git a/src/sqlancer/common/ast/newast/NewUnaryPrefixOperatorNode.java b/src/sqlancer/common/ast/newast/NewUnaryPrefixOperatorNode.java index 5274bd942..8668aec35 100644 --- a/src/sqlancer/common/ast/newast/NewUnaryPrefixOperatorNode.java +++ b/src/sqlancer/common/ast/newast/NewUnaryPrefixOperatorNode.java @@ -2,12 +2,12 @@ import sqlancer.common.ast.BinaryOperatorNode.Operator; -public class NewUnaryPrefixOperatorNode implements Node { +public class NewUnaryPrefixOperatorNode { protected final Operator op; - private final Node expr; + private final T expr; - public NewUnaryPrefixOperatorNode(Node expr, Operator op) { + public NewUnaryPrefixOperatorNode(T expr, Operator op) { this.expr = expr; this.op = op; } @@ -16,7 +16,7 @@ public String getOperatorRepresentation() { return op.getTextRepresentation(); } - public Node getExpr() { + public T getExpr() { return expr; } diff --git a/src/sqlancer/common/ast/newast/Node.java b/src/sqlancer/common/ast/newast/Node.java deleted file mode 100644 index 310e875c0..000000000 --- a/src/sqlancer/common/ast/newast/Node.java +++ /dev/null @@ -1,5 +0,0 @@ -package sqlancer.common.ast.newast; - -public interface Node { - -} diff --git a/src/sqlancer/common/ast/newast/Select.java b/src/sqlancer/common/ast/newast/Select.java new file mode 100644 index 000000000..53520648d --- /dev/null +++ b/src/sqlancer/common/ast/newast/Select.java @@ -0,0 +1,48 @@ +package sqlancer.common.ast.newast; + +import java.util.List; + +import sqlancer.common.schema.AbstractTable; +import sqlancer.common.schema.AbstractTableColumn; + +public interface Select, E extends Expression, T extends AbstractTable, C extends AbstractTableColumn> + extends Expression { + + List getFromList(); + + void setFromList(List fromList); + + Expression getWhereClause(); + + void setWhereClause(E whereClause); + + void setGroupByClause(List groupByClause); + + List getGroupByClause(); + + void setLimitClause(E limitClause); + + Expression getLimitClause(); + + List getOrderByClauses(); + + void setOrderByClauses(List orderBy); + + void setOffsetClause(E offsetClause); + + Expression getOffsetClause(); + + void setFetchColumns(List fetchColumns); + + List getFetchColumns(); + + void setJoinClauses(List joinStatements); + + List getJoinClauses(); + + void setHavingClause(E havingClause); + + Expression getHavingClause(); + + String asString(); +} diff --git a/src/sqlancer/common/ast/newast/TableReferenceNode.java b/src/sqlancer/common/ast/newast/TableReferenceNode.java index 981da70ad..0719a0ed6 100644 --- a/src/sqlancer/common/ast/newast/TableReferenceNode.java +++ b/src/sqlancer/common/ast/newast/TableReferenceNode.java @@ -2,7 +2,7 @@ import sqlancer.common.schema.AbstractTable; -public class TableReferenceNode> implements Node { +public class TableReferenceNode> { private final T t; diff --git a/src/sqlancer/common/gen/AbstractInsertGenerator.java b/src/sqlancer/common/gen/AbstractInsertGenerator.java index 61357fd10..1a0b2a997 100644 --- a/src/sqlancer/common/gen/AbstractInsertGenerator.java +++ b/src/sqlancer/common/gen/AbstractInsertGenerator.java @@ -24,6 +24,6 @@ protected void insertColumns(List columns) { } } - protected abstract void insertValue(C tiDBColumn); + protected abstract void insertValue(C column); } diff --git a/src/sqlancer/common/gen/AbstractUpdateGenerator.java b/src/sqlancer/common/gen/AbstractUpdateGenerator.java new file mode 100644 index 000000000..f130c15a5 --- /dev/null +++ b/src/sqlancer/common/gen/AbstractUpdateGenerator.java @@ -0,0 +1,26 @@ +package sqlancer.common.gen; + +import java.util.List; + +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.schema.AbstractTableColumn; + +public abstract class AbstractUpdateGenerator> { + + protected final ExpectedErrors errors = new ExpectedErrors(); + protected StringBuilder sb = new StringBuilder(); + + protected void updateColumns(List columns) { + for (int nrColumn = 0; nrColumn < columns.size(); nrColumn++) { + if (nrColumn != 0) { + sb.append(", "); + } + sb.append(columns.get(nrColumn).getName()); + sb.append("="); + updateValue(columns.get(nrColumn)); + } + } + + protected abstract void updateValue(C column); + +} diff --git a/src/sqlancer/common/gen/CERTGenerator.java b/src/sqlancer/common/gen/CERTGenerator.java new file mode 100644 index 000000000..b272ba4b2 --- /dev/null +++ b/src/sqlancer/common/gen/CERTGenerator.java @@ -0,0 +1,29 @@ +package sqlancer.common.gen; + +import java.util.List; + +import sqlancer.common.ast.newast.Expression; +import sqlancer.common.ast.newast.Join; +import sqlancer.common.ast.newast.Select; +import sqlancer.common.schema.AbstractTable; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.common.schema.AbstractTables; + +public interface CERTGenerator, J extends Join, E extends Expression, T extends AbstractTable, C extends AbstractTableColumn> { + + CERTGenerator setTablesAndColumns(AbstractTables tables); + + E generateBooleanExpression(); + + S generateSelect(); + + List getRandomJoinClauses(); + + List getTableRefs(); + + List generateFetchColumns(boolean shouldCreateDummy); + + String generateExplainQuery(S select); + + boolean mutate(S select); +} diff --git a/src/sqlancer/common/gen/NoRECGenerator.java b/src/sqlancer/common/gen/NoRECGenerator.java new file mode 100644 index 000000000..185e8286b --- /dev/null +++ b/src/sqlancer/common/gen/NoRECGenerator.java @@ -0,0 +1,49 @@ +package sqlancer.common.gen; + +import java.util.List; + +import sqlancer.common.ast.newast.Expression; +import sqlancer.common.ast.newast.Join; +import sqlancer.common.ast.newast.Select; +import sqlancer.common.schema.AbstractTable; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.common.schema.AbstractTables; + +public interface NoRECGenerator, J extends Join, E extends Expression, T extends AbstractTable, C extends AbstractTableColumn> { + + NoRECGenerator setTablesAndColumns(AbstractTables tables); + + E generateBooleanExpression(); + + S generateSelect(); + + List getRandomJoinClauses(); + + List getTableRefs(); + + /** + * Generates a query string that is likely to be optimized by the DBMS. + * + * @param select + * the base select expression used to generate the query + * @param whereCondition + * a condition where records will be checked with + * @param shouldUseAggregate + * whether to aggregate the record counts (`true`) or display records as is (`false`) + * + * @return a query string to be executed + */ + String generateOptimizedQueryString(S select, E whereCondition, boolean shouldUseAggregate); + + /** + * Generates a query string that is unlikely to be optimized by the DBMS. + * + * @param select + * the base select expression used to generate the query + * @param whereCondition + * the condition each record will be checked with + * + * @return a query string to be executed + */ + String generateUnoptimizedQueryString(S select, E whereCondition); +} diff --git a/src/sqlancer/common/gen/PartitionGenerator.java b/src/sqlancer/common/gen/PartitionGenerator.java new file mode 100644 index 000000000..affc62c42 --- /dev/null +++ b/src/sqlancer/common/gen/PartitionGenerator.java @@ -0,0 +1,27 @@ +package sqlancer.common.gen; + +import sqlancer.common.ast.newast.Expression; +import sqlancer.common.schema.AbstractTableColumn; + +public interface PartitionGenerator, C extends AbstractTableColumn> { + + /** + * Negates a predicate (i.e., uses a NOT operator). + * + * @param predicate + * the boolean predicate. + * + * @return the negated predicate. + */ + E negatePredicate(E predicate); + + /** + * Checks if an expression evaluates to NULL (i.e., implements the IS NULL operator). + * + * @param expr + * the expression + * + * @return an expression that checks whether the expression evaluates to NULL. + */ + E isNull(E expr); +} diff --git a/src/sqlancer/common/gen/TLPWhereGenerator.java b/src/sqlancer/common/gen/TLPWhereGenerator.java new file mode 100644 index 000000000..095a878f9 --- /dev/null +++ b/src/sqlancer/common/gen/TLPWhereGenerator.java @@ -0,0 +1,28 @@ +package sqlancer.common.gen; + +import java.util.List; + +import sqlancer.common.ast.newast.Expression; +import sqlancer.common.ast.newast.Join; +import sqlancer.common.ast.newast.Select; +import sqlancer.common.schema.AbstractTable; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.common.schema.AbstractTables; + +public interface TLPWhereGenerator, J extends Join, E extends Expression, T extends AbstractTable, C extends AbstractTableColumn> + extends PartitionGenerator { + + TLPWhereGenerator setTablesAndColumns(AbstractTables tables); + + E generateBooleanExpression(); + + S generateSelect(); + + List getRandomJoinClauses(); + + List getTableRefs(); + + List generateFetchColumns(boolean shouldCreateDummy); + + List generateOrderBys(); +} diff --git a/src/sqlancer/common/log/Loggable.java b/src/sqlancer/common/log/Loggable.java index 0d5fef218..7796009ee 100644 --- a/src/sqlancer/common/log/Loggable.java +++ b/src/sqlancer/common/log/Loggable.java @@ -1,5 +1,7 @@ package sqlancer.common.log; -public interface Loggable { +import java.io.Serializable; + +public interface Loggable extends Serializable { String getLogString(); } diff --git a/src/sqlancer/common/log/LoggedString.java b/src/sqlancer/common/log/LoggedString.java index 4f449034f..696203a78 100644 --- a/src/sqlancer/common/log/LoggedString.java +++ b/src/sqlancer/common/log/LoggedString.java @@ -1,6 +1,7 @@ package sqlancer.common.log; public class LoggedString implements Loggable { + private static final long serialVersionUID = 1L; private final String loggedString; diff --git a/src/sqlancer/common/log/SQLLoggableFactory.java b/src/sqlancer/common/log/SQLLoggableFactory.java index 24a28a603..bdcf2253f 100644 --- a/src/sqlancer/common/log/SQLLoggableFactory.java +++ b/src/sqlancer/common/log/SQLLoggableFactory.java @@ -14,7 +14,9 @@ protected Loggable createLoggable(String input, String suffix) { if (!input.endsWith(";")) { completeString += ";"; } - if (suffix != null && suffix.length() != 0) { + completeString = completeString.replace("\n", "\\n"); + completeString = completeString.replace("\r", "\\r"); + if (suffix != null && !suffix.isEmpty()) { completeString += suffix; } return new LoggedString(completeString); @@ -35,10 +37,10 @@ public SQLQueryAdapter commentOutQuery(Query query) { @Override protected Loggable infoToLoggable(String time, String databaseName, String databaseVersion, long seedValue) { StringBuilder sb = new StringBuilder(); - sb.append("-- Time: " + time + "\n"); - sb.append("-- Database: " + databaseName + "\n"); - sb.append("-- Database version: " + databaseVersion + "\n"); - sb.append("-- seed value: " + seedValue + "\n"); + sb.append("-- Time: ").append(time).append("\n"); + sb.append("-- Database: ").append(databaseName).append("\n"); + sb.append("-- Database version: ").append(databaseVersion).append("\n"); + sb.append("-- seed value: ").append(seedValue).append("\n"); return new LoggedString(sb.toString()); } diff --git a/src/sqlancer/common/oracle/CERTOracle.java b/src/sqlancer/common/oracle/CERTOracle.java new file mode 100644 index 000000000..48c174ef2 --- /dev/null +++ b/src/sqlancer/common/oracle/CERTOracle.java @@ -0,0 +1,133 @@ +package sqlancer.common.oracle; + +import java.io.IOException; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.SQLGlobalState; +import sqlancer.common.DBMSCommon; +import sqlancer.common.ast.newast.Expression; +import sqlancer.common.ast.newast.Join; +import sqlancer.common.ast.newast.Select; +import sqlancer.common.gen.CERTGenerator; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.query.SQLancerResultSet; +import sqlancer.common.schema.AbstractSchema; +import sqlancer.common.schema.AbstractTable; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.common.schema.AbstractTables; + +public class CERTOracle, J extends Join, E extends Expression, S extends AbstractSchema, T extends AbstractTable, C extends AbstractTableColumn, G extends SQLGlobalState> + implements TestOracle { + + private final G state; + private final CheckedFunction> rowCountParser; + private final CheckedFunction> queryPlanParser; + + private CERTGenerator gen; + private final ExpectedErrors errors; + + public CERTOracle(G state, CERTGenerator gen, ExpectedErrors expectedErrors, + CheckedFunction> rowCountParser, + CheckedFunction> queryPlanParser) { + if (state == null || gen == null || expectedErrors == null) { + throw new IllegalArgumentException("Null variables used to initialize test oracle."); + } + this.state = state; + this.gen = gen; + this.errors = expectedErrors; + this.rowCountParser = rowCountParser; + this.queryPlanParser = queryPlanParser; + } + + @Override + public void check() throws SQLException { + S schema = state.getSchema(); + AbstractTables targetTables = TestOracleUtils.getRandomTableNonEmptyTables(schema); + gen = gen.setTablesAndColumns(targetTables); + + List fetchColumns = gen.generateFetchColumns(false); + + Z select = gen.generateSelect(); + select.setFetchColumns(fetchColumns); + select.setJoinClauses(gen.getRandomJoinClauses()); + select.setFromList(gen.getTableRefs()); + + if (Randomly.getBoolean()) { + select.setWhereClause(gen.generateBooleanExpression()); + } + if (Randomly.getBoolean()) { + select.setGroupByClause(fetchColumns); + if (Randomly.getBoolean()) { + select.setHavingClause(gen.generateBooleanExpression()); + } + } + + List queryPlan1Sequences = new ArrayList<>(); + List queryPlan2Sequences = new ArrayList<>(); + + String queryString1 = gen.generateExplainQuery(select); + long rowCount1 = getRow(state, queryString1, queryPlan1Sequences); + + boolean increase = gen.mutate(select); + String queryString2 = gen.generateExplainQuery(select); + long rowCount2 = getRow(state, queryString2, queryPlan2Sequences); + + if (DBMSCommon.editDistance(queryPlan1Sequences, queryPlan2Sequences) > 1) { + return; + } + + // Check the results + if (increase && rowCount1 > rowCount2 || !increase && rowCount1 < rowCount2) { + throw new AssertionError("Inconsistent result for query: " + queryString1 + "; --" + rowCount1 + "\n" + + queryString2 + "; --" + rowCount2); + } + } + + private Long getRow(SQLGlobalState globalState, String explainQuery, List queryPlanSequences) + throws AssertionError, SQLException { + Optional row = Optional.empty(); + + // Log the query + if (globalState.getOptions().logEachSelect()) { + globalState.getLogger().writeCurrent(explainQuery); + try { + globalState.getLogger().getCurrentFileWriter().flush(); + } catch (IOException e) { + e.printStackTrace(); + } + } + + // Get the row count + SQLQueryAdapter q = new SQLQueryAdapter(explainQuery, errors); + try (SQLancerResultSet rs = q.executeAndGet(globalState)) { + if (rs != null) { + while (rs.next()) { + Optional rowCount = rowCountParser.apply(rs); + if (row.isEmpty() && rowCount.isPresent()) { + row = rowCount; + } + + Optional queryPlanSequence = queryPlanParser.apply(rs); + queryPlanSequence.ifPresent(qps -> queryPlanSequences.add(qps)); + } + } + } catch (IgnoreMeException e) { + throw new IgnoreMeException(); + } catch (Exception e) { + throw new AssertionError(q.getQueryString(), e); + } + + return row.orElseThrow(IgnoreMeException::new); + } + + @FunctionalInterface + public interface CheckedFunction { + R apply(T t) throws SQLException; + } +} diff --git a/src/sqlancer/common/oracle/CERTOracleBase.java b/src/sqlancer/common/oracle/CERTOracleBase.java new file mode 100644 index 000000000..42e8e5833 --- /dev/null +++ b/src/sqlancer/common/oracle/CERTOracleBase.java @@ -0,0 +1,88 @@ +package sqlancer.common.oracle; + +import java.util.Arrays; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.SQLGlobalState; +import sqlancer.common.query.ExpectedErrors; + +public abstract class CERTOracleBase> implements TestOracle { + + protected final S state; + protected final ExpectedErrors errors; + protected List queryPlan1Sequences; + protected List queryPlan2Sequences; + + protected enum Mutator { + JOIN, DISTINCT, WHERE, GROUPBY, HAVING, AND, OR, LIMIT; + + public static Mutator getRandomExcept(Mutator... exclude) { + Mutator[] values = Arrays.stream(values()).filter(m -> !Arrays.asList(exclude).contains(m)) + .toArray(Mutator[]::new); + return Randomly.fromOptions(values); + } + } + + protected CERTOracleBase(S state) { + this.state = state; + this.errors = new ExpectedErrors(); + } + + protected boolean mutate(Mutator... exclude) { + Mutator m = Mutator.getRandomExcept(exclude); + switch (m) { + case JOIN: + return mutateJoin(); + case DISTINCT: + return mutateDistinct(); + case WHERE: + return mutateWhere(); + case GROUPBY: + return mutateGroupBy(); + case HAVING: + return mutateHaving(); + case AND: + return mutateAnd(); + case OR: + return mutateOr(); + case LIMIT: + return mutateLimit(); + default: + throw new AssertionError(m); + } + } + + protected boolean mutateJoin() { + throw new UnsupportedOperationException(); + } + + protected boolean mutateDistinct() { + throw new UnsupportedOperationException(); + } + + protected boolean mutateWhere() { + throw new UnsupportedOperationException(); + } + + protected boolean mutateGroupBy() { + throw new UnsupportedOperationException(); + } + + protected boolean mutateHaving() { + throw new UnsupportedOperationException(); + } + + protected boolean mutateAnd() { + throw new UnsupportedOperationException(); + } + + protected boolean mutateOr() { + throw new UnsupportedOperationException(); + } + + protected boolean mutateLimit() { + throw new UnsupportedOperationException(); + } + +} diff --git a/src/sqlancer/common/oracle/CODDTestBase.java b/src/sqlancer/common/oracle/CODDTestBase.java new file mode 100644 index 000000000..639a4f077 --- /dev/null +++ b/src/sqlancer/common/oracle/CODDTestBase.java @@ -0,0 +1,25 @@ +package sqlancer.common.oracle; + +import sqlancer.Main.StateLogger; +import sqlancer.MainOptions; +import sqlancer.SQLConnection; +import sqlancer.SQLGlobalState; +import sqlancer.common.query.ExpectedErrors; + +public abstract class CODDTestBase> implements TestOracle { + protected final S state; + protected final ExpectedErrors errors = new ExpectedErrors(); + protected final StateLogger logger; + protected final MainOptions options; + protected final SQLConnection con; + protected String auxiliaryQueryString; + protected String foldedQueryString; + protected String originalQueryString; + + public CODDTestBase(S state) { + this.state = state; + this.con = state.getConnection(); + this.logger = state.getLogger(); + this.options = state.getOptions(); + } +} diff --git a/src/sqlancer/common/oracle/CompositeTestOracle.java b/src/sqlancer/common/oracle/CompositeTestOracle.java index fbb66960a..e2f96f785 100644 --- a/src/sqlancer/common/oracle/CompositeTestOracle.java +++ b/src/sqlancer/common/oracle/CompositeTestOracle.java @@ -4,27 +4,34 @@ import sqlancer.GlobalState; -public class CompositeTestOracle implements TestOracle { +public class CompositeTestOracle> implements TestOracle { - private final TestOracle[] oracles; - private final GlobalState globalState; + private final List> oracles; + private final G globalState; private int i; + private int iLast; - public CompositeTestOracle(List oracles, GlobalState globalState) { + public CompositeTestOracle(List> oracles, G globalState) { this.globalState = globalState; - this.oracles = oracles.toArray(new TestOracle[oracles.size()]); + this.oracles = oracles; } @Override public void check() throws Exception { try { - oracles[i].check(); - boolean lastOracleIndex = i == oracles.length - 1; + oracles.get(i).check(); + iLast = i; + boolean lastOracleIndex = i == oracles.size() - 1; if (!lastOracleIndex) { globalState.getManager().incrementSelectQueryCount(); } } finally { - i = (i + 1) % oracles.length; + i = (i + 1) % oracles.size(); } } + + @Override + public String getLastQueryString() { + return oracles.get(iLast).getLastQueryString(); + } } diff --git a/src/sqlancer/common/oracle/DQEBase.java b/src/sqlancer/common/oracle/DQEBase.java new file mode 100644 index 000000000..0d954b115 --- /dev/null +++ b/src/sqlancer/common/oracle/DQEBase.java @@ -0,0 +1,145 @@ +package sqlancer.common.oracle; + +import java.sql.SQLException; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import sqlancer.Main; +import sqlancer.MainOptions; +import sqlancer.SQLConnection; +import sqlancer.SQLGlobalState; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.query.SQLQueryError; +import sqlancer.common.schema.AbstractRelationalTable; +import sqlancer.common.schema.AbstractTables; + +/* + * In DBMSs, SELECT, UPDATE and DELETE queries utilize predicates (i.e., WHERE clauses) to specify which rows to retrieve, update or delete, respectively. + * If they use the same predicate φ, they should access the same rows in a database. + * Ideally, DBMSs can adopt the same implementations for predicate evaluation in SELECT, UPDATE and DELETE queries. + * However, a DBMS usually adopts different implementations for predicate evaluation in SELECT, UPDATE and DELETE queries due to various optimization choices. + * Inconsistent implementations for predicate evaluation among these queries can cause SELECT, UPDATE and DELETE queries with the same predicate φ to access different rows. + * + * + * Inspired by this key observation, we propose Differential Query Execution(DQE), a novel and general approach to detect logic bugs in SELECT, UPDATE and DELETE queries. + * DQE solves the test oracle problem by executing SELECT, UPDATE and DELETE queries with the same predicate φ, and observing inconsistencies among their execution results. + * For example, if a row that is updated by an UPDATE query with a predicate φ does not appear in the query result of a SELECT query with the same predicate φ, a logic bug is detected in the target DBMS. + * The key challenge of DQE is to automatically obtain the accessed rows for a given SELECT, UPDATE or DELETE query. + * To address this challenge, we append two extra columns to each table in a database, to uniquely identify each row and track whether a row has been modified, respectively. + * We further rewrite SELECT and UPDATE queries to identify their accessed rows. + * + * more information see [DQE paper](https://ieeexplore.ieee.org/document/10172736) + */ + +public abstract class DQEBase> { + + public static final String COLUMN_ROWID = "rowId"; + public static final String COLUMN_UPDATED = "updated"; + + protected final S state; + protected final ExpectedErrors selectExpectedErrors = new ExpectedErrors(); + protected final ExpectedErrors updateExpectedErrors = new ExpectedErrors(); + protected final ExpectedErrors deleteExpectedErrors = new ExpectedErrors(); + + protected final Main.StateLogger logger; + protected final MainOptions options; + protected final SQLConnection con; + + public DQEBase(S state) { + this.state = state; + this.con = state.getConnection(); + this.logger = state.getLogger(); + this.options = state.getOptions(); + } + + public abstract String generateSelectStatement(AbstractTables tables, String tableName, + String whereClauseStr); + + public abstract String generateUpdateStatement(AbstractTables tables, String tableName, + String whereClauseStr); + + public abstract String generateDeleteStatement(String tableName, String whereClauseStr); + + // Add auxiliary columns to the database A abstract method, subclasses need to implement it. + public abstract void addAuxiliaryColumns(AbstractRelationalTable table) throws SQLException; + + public void dropAuxiliaryColumns(AbstractRelationalTable table) throws SQLException { + String tableName = table.getName(); + String dropColumnRowId = String.format("ALTER TABLE %s DROP COLUMN %s", tableName, COLUMN_ROWID); + new SQLQueryAdapter(dropColumnRowId).execute(state); + String dropColumnUpdated = String.format("ALTER TABLE %s DROP COLUMN %s", tableName, COLUMN_UPDATED); + new SQLQueryAdapter(dropColumnUpdated).execute(state); + } + + // This interface is to record Error code + public interface UpdateErrorCodes { + + } + + public interface ErrorCodeStrategy { + Set getUpdateSpecificErrorCodes(); + + Set getDeleteSpecificErrorCodes(); + + } + + /** + * The core idea of DQE is that the SELECT, UPDATE and DELETE queries with the same predicate φ should access the + * same rows. If these queries access different rows, DQE reveals a potential logic bug in the target DBMS. + */ + public static class SQLQueryResult { + + private final Map, Set> accessedRows; // Table name with respect rows + private final List queryErrors; + + public SQLQueryResult(Map, Set> accessedRows, + List queryErrors) { + this.accessedRows = accessedRows; + this.queryErrors = queryErrors; + } + + public Map, Set> getAccessedRows() { + return accessedRows; + } + + public List getQueryErrors() { + return queryErrors; + } + + public boolean hasEmptyErrors() { + return queryErrors.isEmpty(); + } + + public boolean hasSameErrors(SQLQueryResult that) { + if (queryErrors.size() != that.getQueryErrors().size()) { + return false; + } else { + for (int i = 0; i < queryErrors.size(); i++) { + if (!queryErrors.get(i).equals(that.getQueryErrors().get(i))) { + return false; + } + } + } + return true; + } + + public boolean hasAccessedRows() { + if (accessedRows.isEmpty()) { + return false; + } + for (Set accessedRow : accessedRows.values()) { + if (!accessedRow.isEmpty()) { + return true; + } + } + return false; + } + + public boolean hasSameAccessedRows(SQLQueryResult that) { + return accessedRows.equals(that.getAccessedRows()); + } + + } +} diff --git a/src/sqlancer/common/oracle/DocumentRemovalOracleBase.java b/src/sqlancer/common/oracle/DocumentRemovalOracleBase.java index a48a6f17e..b6c0ee509 100644 --- a/src/sqlancer/common/oracle/DocumentRemovalOracleBase.java +++ b/src/sqlancer/common/oracle/DocumentRemovalOracleBase.java @@ -3,7 +3,7 @@ import sqlancer.GlobalState; import sqlancer.common.gen.ExpressionGenerator; -public abstract class DocumentRemovalOracleBase> implements TestOracle { +public abstract class DocumentRemovalOracleBase> implements TestOracle { protected E predicate; diff --git a/src/sqlancer/common/oracle/NoRECBase.java b/src/sqlancer/common/oracle/NoRECBase.java index 734e04a86..2ac0dbb43 100644 --- a/src/sqlancer/common/oracle/NoRECBase.java +++ b/src/sqlancer/common/oracle/NoRECBase.java @@ -6,7 +6,7 @@ import sqlancer.SQLGlobalState; import sqlancer.common.query.ExpectedErrors; -public abstract class NoRECBase> implements TestOracle { +public abstract class NoRECBase> implements TestOracle { protected final S state; protected final ExpectedErrors errors = new ExpectedErrors(); @@ -16,7 +16,7 @@ public abstract class NoRECBase> implements TestO protected String optimizedQueryString; protected String unoptimizedQueryString; - public NoRECBase(S state) { + protected NoRECBase(S state) { this.state = state; this.con = state.getConnection(); this.logger = state.getLogger(); diff --git a/src/sqlancer/common/oracle/NoRECOracle.java b/src/sqlancer/common/oracle/NoRECOracle.java new file mode 100644 index 000000000..caf3dff87 --- /dev/null +++ b/src/sqlancer/common/oracle/NoRECOracle.java @@ -0,0 +1,172 @@ +package sqlancer.common.oracle; + +import java.sql.SQLException; +import java.util.Objects; +import java.util.function.Function; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.Reproducer; +import sqlancer.SQLGlobalState; +import sqlancer.common.ast.newast.Expression; +import sqlancer.common.ast.newast.Join; +import sqlancer.common.ast.newast.Select; +import sqlancer.common.gen.NoRECGenerator; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.query.SQLancerResultSet; +import sqlancer.common.schema.AbstractSchema; +import sqlancer.common.schema.AbstractTable; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.common.schema.AbstractTables; + +public class NoRECOracle, J extends Join, E extends Expression, S extends AbstractSchema, T extends AbstractTable, C extends AbstractTableColumn, G extends SQLGlobalState> + implements TestOracle { + + private final G state; + + private NoRECGenerator gen; + private final ExpectedErrors errors; + + private Reproducer reproducer; + private String lastQueryString; + + private static class NoRECReproducer> implements Reproducer { + private final Function optimizedQuery; + private final Function unoptimizedQuery; + + NoRECReproducer(Function optimizedQuery, Function unoptimizedQuery) { + this.optimizedQuery = optimizedQuery; + this.unoptimizedQuery = unoptimizedQuery; + } + + @Override + public boolean bugStillTriggers(G globalState) { + return !Objects.equals(optimizedQuery.apply(globalState), unoptimizedQuery.apply(globalState)); + } + } + + public NoRECOracle(G state, NoRECGenerator gen, ExpectedErrors expectedErrors) { + if (state == null || gen == null || expectedErrors == null) { + throw new IllegalArgumentException("Null variables used to initialize test oracle."); + } + this.state = state; + this.gen = gen; + this.errors = expectedErrors; + this.reproducer = null; + } + + @Override + public void check() throws SQLException { + reproducer = null; + S schema = state.getSchema(); + AbstractTables targetTables = TestOracleUtils.getRandomTableNonEmptyTables(schema); + gen = gen.setTablesAndColumns(targetTables); + + Z select = gen.generateSelect(); + select.setJoinClauses(gen.getRandomJoinClauses()); + select.setFromList(gen.getTableRefs()); + + E randomWhereCondition = gen.generateBooleanExpression(); + + boolean shouldUseAggregate = Randomly.getBoolean(); + String optimizedQueryString = gen.generateOptimizedQueryString(select, randomWhereCondition, + shouldUseAggregate); + lastQueryString = optimizedQueryString; + if (state.getOptions().logEachSelect()) { + state.getLogger().writeCurrent(optimizedQueryString); + } + + String unoptimizedQueryString = gen.generateUnoptimizedQueryString(select, randomWhereCondition); + if (state.getOptions().logEachSelect()) { + state.getLogger().writeCurrent(unoptimizedQueryString); + } + + int optimizedCount = shouldUseAggregate ? extractCounts(optimizedQueryString, errors, state) + : countRows(optimizedQueryString, errors, state); + int unoptimizedCount = extractCounts(unoptimizedQueryString, errors, state); + + if (optimizedCount == -1 || unoptimizedCount == -1) { + throw new IgnoreMeException(); + } + + if (unoptimizedCount != optimizedCount) { + Function optimizedQuery = state -> shouldUseAggregate + ? extractCounts(optimizedQueryString, errors, state) + : countRows(optimizedQueryString, errors, state); + + Function unoptimizedQuery = state -> extractCounts(unoptimizedQueryString, errors, state); + reproducer = new NoRECReproducer<>(optimizedQuery, unoptimizedQuery); + + String queryFormatString = "-- %s;\n-- count: %d"; + String firstQueryStringWithCount = String.format(queryFormatString, optimizedQueryString, optimizedCount); + String secondQueryStringWithCount = String.format(queryFormatString, unoptimizedQueryString, + unoptimizedCount); + state.getState().getLocalState() + .log(String.format("%s\n%s", firstQueryStringWithCount, secondQueryStringWithCount)); + String assertionMessage = String.format("the counts mismatch (%d and %d)!\n%s\n%s", optimizedCount, + unoptimizedCount, firstQueryStringWithCount, secondQueryStringWithCount); + throw new AssertionError(assertionMessage); + } + } + + @Override + public String getLastQueryString() { + return lastQueryString; + } + + @Override + public Reproducer getLastReproducer() { + return reproducer; + } + + private int countRows(String queryString, ExpectedErrors errors, SQLGlobalState state) { + SQLQueryAdapter q = new SQLQueryAdapter(queryString, errors, false, false); + + int count = 0; + try (SQLancerResultSet rs = q.executeAndGet(state)) { + if (rs == null) { + return -1; + } else { + try { + while (rs.next()) { + count++; + } + } catch (SQLException e) { + count = -1; + } + } + } catch (Exception e) { + if (e instanceof IgnoreMeException) { + throw (IgnoreMeException) e; + } + throw new AssertionError(q.getQueryString(), e); + } + return count; + } + + private int extractCounts(String queryString, ExpectedErrors errors, SQLGlobalState state) { + SQLQueryAdapter q = new SQLQueryAdapter(queryString, errors, false, false); + int count = 0; + try (SQLancerResultSet rs = q.executeAndGet(state)) { + if (rs == null) { + return -1; + } else { + try { + while (rs.next()) { + count += rs.getInt(1); + } + } catch (SQLException e) { + count = -1; + } + } + } catch (Exception e) { + if (e instanceof IgnoreMeException) { + throw (IgnoreMeException) e; + } + throw new AssertionError(q.getQueryString(), e); + } + return count; + } + +} diff --git a/src/sqlancer/common/oracle/PivotedQuerySynthesisBase.java b/src/sqlancer/common/oracle/PivotedQuerySynthesisBase.java index cd2580b76..cde8a2e4b 100644 --- a/src/sqlancer/common/oracle/PivotedQuerySynthesisBase.java +++ b/src/sqlancer/common/oracle/PivotedQuerySynthesisBase.java @@ -12,7 +12,7 @@ import sqlancer.common.schema.AbstractRowValue; public abstract class PivotedQuerySynthesisBase, R extends AbstractRowValue, E, C extends SQLancerDBConnection> - implements TestOracle { + implements TestOracle { protected final ExpectedErrors errors = new ExpectedErrors(); @@ -29,7 +29,7 @@ public abstract class PivotedQuerySynthesisBase, protected final S globalState; protected R pivotRow; - public PivotedQuerySynthesisBase(S globalState) { + protected PivotedQuerySynthesisBase(S globalState) { this.globalState = globalState; } diff --git a/src/sqlancer/common/oracle/TLPWhereOracle.java b/src/sqlancer/common/oracle/TLPWhereOracle.java new file mode 100644 index 000000000..14834a62f --- /dev/null +++ b/src/sqlancer/common/oracle/TLPWhereOracle.java @@ -0,0 +1,129 @@ +package sqlancer.common.oracle; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import sqlancer.ComparatorHelper; +import sqlancer.Randomly; +import sqlancer.Reproducer; +import sqlancer.SQLGlobalState; +import sqlancer.common.ast.newast.Expression; +import sqlancer.common.ast.newast.Join; +import sqlancer.common.ast.newast.Select; +import sqlancer.common.gen.TLPWhereGenerator; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.schema.AbstractSchema; +import sqlancer.common.schema.AbstractTable; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.common.schema.AbstractTables; + +public class TLPWhereOracle, J extends Join, E extends Expression, S extends AbstractSchema, T extends AbstractTable, C extends AbstractTableColumn, G extends SQLGlobalState> + implements TestOracle { + + private final G state; + + private TLPWhereGenerator gen; + private final ExpectedErrors errors; + + private Reproducer reproducer; + private String generatedQueryString; + + private class TLPWhereReproducer implements Reproducer { + final String firstQueryString; + final String secondQueryString; + final String thirdQueryString; + final String originalQueryString; + final List resultSet; + final boolean orderBy; + + TLPWhereReproducer(String firstQueryString, String secondQueryString, String thirdQueryString, + String originalQueryString, List resultSet, boolean orderBy) { + this.firstQueryString = firstQueryString; + this.secondQueryString = secondQueryString; + this.thirdQueryString = thirdQueryString; + this.originalQueryString = originalQueryString; + this.resultSet = resultSet; + this.orderBy = orderBy; + } + + @Override + public boolean bugStillTriggers(G globalState) { + try { + List combinedString1 = new ArrayList<>(); + List secondResultSet1 = ComparatorHelper.getCombinedResultSet(firstQueryString, + secondQueryString, thirdQueryString, combinedString1, !orderBy, globalState, errors); + ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet1, originalQueryString, + combinedString1, globalState); + } catch (AssertionError triggeredError) { + return true; + } catch (SQLException ignored) { + } + return false; + } + } + + public TLPWhereOracle(G state, TLPWhereGenerator gen, ExpectedErrors expectedErrors) { + if (state == null || gen == null || expectedErrors == null) { + throw new IllegalArgumentException("Null variables used to initialize test oracle."); + } + this.state = state; + this.gen = gen; + this.errors = expectedErrors; + } + + @Override + public void check() throws SQLException { + reproducer = null; + S s = state.getSchema(); + AbstractTables targetTables = TestOracleUtils.getRandomTableNonEmptyTables(s); + gen = gen.setTablesAndColumns(targetTables); + + Select select = gen.generateSelect(); + + boolean shouldCreateDummy = true; + select.setFetchColumns(gen.generateFetchColumns(shouldCreateDummy)); + select.setJoinClauses(gen.getRandomJoinClauses()); + select.setFromList(gen.getTableRefs()); + select.setWhereClause(null); + + String originalQueryString = select.asString(); + generatedQueryString = originalQueryString; + List firstResultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, + state); + + boolean orderBy = Randomly.getBooleanWithSmallProbability(); + if (orderBy) { + select.setOrderByClauses(gen.generateOrderBys()); + } + + TestOracleUtils.PredicateVariants predicates = TestOracleUtils.initializeTernaryPredicateVariants(gen, + gen.generateBooleanExpression()); + select.setWhereClause(predicates.predicate); + String firstQueryString = select.asString(); + select.setWhereClause(predicates.negatedPredicate); + String secondQueryString = select.asString(); + select.setWhereClause(predicates.isNullPredicate); + String thirdQueryString = select.asString(); + + List combinedString = new ArrayList<>(); + List secondResultSet = ComparatorHelper.getCombinedResultSet(firstQueryString, secondQueryString, + thirdQueryString, combinedString, !orderBy, state, errors); + + ComparatorHelper.assumeResultSetsAreEqual(firstResultSet, secondResultSet, originalQueryString, combinedString, + state); + + reproducer = new TLPWhereReproducer(firstQueryString, secondQueryString, thirdQueryString, originalQueryString, + firstResultSet, orderBy); + } + + @Override + public Reproducer getLastReproducer() { + return reproducer; + } + + @Override + public String getLastQueryString() { + return generatedQueryString; + } +} diff --git a/src/sqlancer/common/oracle/TernaryLogicPartitioningOracleBase.java b/src/sqlancer/common/oracle/TernaryLogicPartitioningOracleBase.java index 3b5d87814..b4d6add4b 100644 --- a/src/sqlancer/common/oracle/TernaryLogicPartitioningOracleBase.java +++ b/src/sqlancer/common/oracle/TernaryLogicPartitioningOracleBase.java @@ -14,7 +14,7 @@ * @param * the global state type */ -public abstract class TernaryLogicPartitioningOracleBase> implements TestOracle { +public abstract class TernaryLogicPartitioningOracleBase> implements TestOracle { protected E predicate; protected E negatedPredicate; diff --git a/src/sqlancer/common/oracle/TestOracle.java b/src/sqlancer/common/oracle/TestOracle.java index 737adb692..0ef993b89 100644 --- a/src/sqlancer/common/oracle/TestOracle.java +++ b/src/sqlancer/common/oracle/TestOracle.java @@ -1,7 +1,17 @@ package sqlancer.common.oracle; -public interface TestOracle { +import sqlancer.GlobalState; +import sqlancer.Reproducer; + +public interface TestOracle> { void check() throws Exception; + default Reproducer getLastReproducer() { + return null; + } + + default String getLastQueryString() { + throw new AssertionError("Not supported!"); + } } diff --git a/src/sqlancer/common/oracle/TestOracleUtils.java b/src/sqlancer/common/oracle/TestOracleUtils.java new file mode 100644 index 000000000..bab2e26c9 --- /dev/null +++ b/src/sqlancer/common/oracle/TestOracleUtils.java @@ -0,0 +1,55 @@ +package sqlancer.common.oracle; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.common.ast.newast.Expression; +import sqlancer.common.gen.PartitionGenerator; +import sqlancer.common.schema.AbstractSchema; +import sqlancer.common.schema.AbstractTable; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.common.schema.AbstractTables; + +public final class TestOracleUtils { + + private TestOracleUtils() { + } + + public static final class PredicateVariants, C extends AbstractTableColumn> { + public E predicate; + public E negatedPredicate; + public E isNullPredicate; + + PredicateVariants(E predicate, E negatedPredicate, E isNullPredicate) { + this.predicate = predicate; + this.negatedPredicate = negatedPredicate; + this.isNullPredicate = isNullPredicate; + } + } + + public static , C extends AbstractTableColumn> AbstractTables getRandomTableNonEmptyTables( + AbstractSchema schema) { + if (schema.getDatabaseTables().isEmpty()) { + throw new IgnoreMeException(); + } + return new AbstractTables<>(Randomly.nonEmptySubset(schema.getDatabaseTables())); + } + + public static , T extends AbstractTable, C extends AbstractTableColumn> PredicateVariants initializeTernaryPredicateVariants( + PartitionGenerator gen, E predicate) { + if (gen == null) { + throw new IllegalStateException(); + } + if (predicate == null) { + throw new IllegalStateException(); + } + E negatedPredicate = gen.negatePredicate(predicate); + if (negatedPredicate == null) { + throw new IllegalStateException(); + } + E isNullPredicate = gen.isNull(predicate); + if (isNullPredicate == null) { + throw new IllegalStateException(); + } + return new PredicateVariants<>(predicate, negatedPredicate, isNullPredicate); + } +} diff --git a/src/sqlancer/common/query/ExpectedErrors.java b/src/sqlancer/common/query/ExpectedErrors.java index 20a7d97c8..eb281efbe 100644 --- a/src/sqlancer/common/query/ExpectedErrors.java +++ b/src/sqlancer/common/query/ExpectedErrors.java @@ -1,17 +1,32 @@ package sqlancer.common.query; +import java.io.Serializable; +import java.util.Arrays; import java.util.Collection; import java.util.HashSet; import java.util.Set; +import java.util.regex.Pattern; /** * This class represents the errors that executing a statement might result in. For example, an INSERT statement might * result in an error "UNIQUE constraint violated" when it attempts to insert a duplicate value in a column declared as * UNIQUE. */ -public class ExpectedErrors { +public class ExpectedErrors implements Serializable { + private static final long serialVersionUID = 1L; - private final Set errors = new HashSet<>(); + private final Set errors; + private final Set regexes; + + public ExpectedErrors() { + this.errors = new HashSet<>(); + this.regexes = new HashSet<>(); + } + + public ExpectedErrors(Collection errors, Collection regexErrors) { + this.errors = new HashSet<>(errors); + this.regexes = new HashSet<>(regexErrors); + } public ExpectedErrors add(String error) { if (error == null) { @@ -21,6 +36,53 @@ public ExpectedErrors add(String error) { return this; } + public ExpectedErrors addRegex(Pattern errorPattern) { + if (errorPattern == null) { + throw new IllegalArgumentException(); + } + regexes.add(errorPattern); + return this; + } + + public ExpectedErrors addRegexString(String errorPattern) { + if (errorPattern == null) { + throw new IllegalArgumentException(); + } + regexes.add(Pattern.compile(errorPattern)); + return this; + } + + public ExpectedErrors addAll(Collection list) { + if (list == null) { + throw new IllegalArgumentException(); + } + errors.addAll(list); + return this; + } + + public ExpectedErrors addAllRegexes(Collection list) { + if (list == null) { + throw new IllegalArgumentException(); + } + regexes.addAll(list); + return this; + } + + public ExpectedErrors addAllRegexStrings(Collection list) { + for (String error : list) { + regexes.add(Pattern.compile(error)); + } + return this; + } + + public static ExpectedErrors from(String... errors) { + return newErrors().with(errors).build(); + } + + public static ExpectedErrorsBuilder newErrors() { + return new ExpectedErrorsBuilder(); + } + /** * Checks whether the error message (e.g., returned by the DBMS under test) contains any of the added error * messages. @@ -34,25 +96,54 @@ public boolean errorIsExpected(String error) { if (error == null) { throw new IllegalArgumentException(); } - for (String s : errors) { + for (String s : this.errors) { if (error.contains(s)) { return true; } } + for (Pattern p : this.regexes) { + if (p.matcher(error).find()) { + return true; + } + } return false; } - public ExpectedErrors addAll(Collection list) { - errors.addAll(list); - return this; - } + public static class ExpectedErrorsBuilder { + private final Set errors = new HashSet<>(); + private final Set regexes = new HashSet<>(); - public static ExpectedErrors from(String... errors) { - ExpectedErrors expectedErrors = new ExpectedErrors(); - for (String error : errors) { - expectedErrors.add(error); + public ExpectedErrorsBuilder with(String... list) { + errors.addAll(Arrays.asList(list)); + return this; + } + + public ExpectedErrorsBuilder with(Collection list) { + return with(list.toArray(new String[0])); + } + + public ExpectedErrorsBuilder withRegex(Pattern... list) { + regexes.addAll(Arrays.asList(list)); + return this; + } + + public ExpectedErrorsBuilder withRegex(Collection list) { + return withRegex(list.toArray(new Pattern[0])); } - return expectedErrors; - } + public ExpectedErrorsBuilder withRegexString(String... list) { + for (String error : list) { + regexes.add(Pattern.compile(error)); + } + return this; + } + + public ExpectedErrorsBuilder withRegexString(Collection list) { + return withRegexString(list.toArray(new String[0])); + } + + public ExpectedErrors build() { + return new ExpectedErrors(errors, regexes); + } + } } diff --git a/src/sqlancer/common/query/Query.java b/src/sqlancer/common/query/Query.java index 44efc0084..ca90619c7 100644 --- a/src/sqlancer/common/query/Query.java +++ b/src/sqlancer/common/query/Query.java @@ -5,6 +5,7 @@ import sqlancer.common.log.Loggable; public abstract class Query implements Loggable { + private static final long serialVersionUID = 1L; /** * Gets the query string, which is guaranteed to be terminated with a semicolon. diff --git a/src/sqlancer/common/query/SQLQueryAdapter.java b/src/sqlancer/common/query/SQLQueryAdapter.java index 01fcc8539..db8a2c66d 100644 --- a/src/sqlancer/common/query/SQLQueryAdapter.java +++ b/src/sqlancer/common/query/SQLQueryAdapter.java @@ -1,5 +1,6 @@ package sqlancer.common.query; +import java.io.Serializable; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; @@ -9,7 +10,8 @@ import sqlancer.Main; import sqlancer.SQLConnection; -public class SQLQueryAdapter extends Query { +public class SQLQueryAdapter extends Query implements Serializable { + private static final long serialVersionUID = 1L; private final String query; private final ExpectedErrors expectedErrors; @@ -24,11 +26,24 @@ public SQLQueryAdapter(String query, boolean couldAffectSchema) { } public SQLQueryAdapter(String query, ExpectedErrors expectedErrors) { - this(query, expectedErrors, false); + this(query, expectedErrors, guessAffectSchemaFromQuery(query)); + } + + private static boolean guessAffectSchemaFromQuery(String query) { + return query.contains("CREATE TABLE") && !query.startsWith("EXPLAIN"); } public SQLQueryAdapter(String query, ExpectedErrors expectedErrors, boolean couldAffectSchema) { - this.query = canonicalizeString(query); + this(query, expectedErrors, couldAffectSchema, true); + } + + public SQLQueryAdapter(String query, ExpectedErrors expectedErrors, boolean couldAffectSchema, + boolean canonicalizeString) { + if (canonicalizeString) { + this.query = canonicalizeString(query); + } else { + this.query = query; + } this.expectedErrors = expectedErrors; this.couldAffectSchema = couldAffectSchema; checkQueryString(); @@ -46,7 +61,7 @@ private String canonicalizeString(String s) { } private void checkQueryString() { - if (query.contains("CREATE TABLE") && !couldAffectSchema) { + if (!couldAffectSchema && guessAffectSchemaFromQuery(query)) { throw new AssertionError("CREATE TABLE statements should set couldAffectSchema to true"); } } @@ -68,17 +83,55 @@ public String getUnterminatedQueryString() { return result; } + /** + * This method is used to mostly oracles, which need to report exceptions. We set the reportException parameter to + * true by default meaning that exceptions are reported. + * + * @param globalState + * @param fills + * + * @return whether the query was executed successfully + * + * @param + * + * @throws SQLException + */ @Override public > boolean execute(G globalState, String... fills) throws SQLException { + return execute(globalState, true, fills); + } + + /** + * This method is used to DQE oracles, DQE does not check exception separately, while other testing methods may + * need. We use reportException to control this behavior. For a specific DBMS used DQE oracle, we call this method + * and pass a boolean value of false as an argument. + * + * @param globalState + * @param reportException + * @param fills + * + * @return whether the query was executed successfully + * + * @param + * + * @throws SQLException + */ + public > boolean execute(G globalState, boolean reportException, + String... fills) throws SQLException { + return internalExecute(globalState.getConnection(), reportException, fills); + } + + protected > boolean internalExecute(SQLConnection connection, + boolean reportException, String... fills) throws SQLException { Statement s; if (fills.length > 0) { - s = globalState.getConnection().prepareStatement(fills[0]); + s = connection.prepareStatement(fills[0]); for (int i = 1; i < fills.length; i++) { ((PreparedStatement) s).setString(i, fills[i]); } } else { - s = globalState.getConnection().createStatement(); + s = connection.createStatement(); } try { if (fills.length > 0) { @@ -90,8 +143,12 @@ public String getUnterminatedQueryString() { return true; } catch (Exception e) { Main.nrUnsuccessfulActions.addAndGet(1); - checkException(e); + if (reportException) { + checkException(e); + } return false; + } finally { + s.close(); } } @@ -112,14 +169,24 @@ public void checkException(Exception e) throws AssertionError { @Override public > SQLancerResultSet executeAndGet(G globalState, String... fills) throws SQLException { + return executeAndGet(globalState, true, fills); + } + + public > SQLancerResultSet executeAndGet(G globalState, + boolean reportException, String... fills) throws SQLException { + return internalExecuteAndGet(globalState.getConnection(), reportException, fills); + } + + protected > SQLancerResultSet internalExecuteAndGet( + SQLConnection connection, boolean reportException, String... fills) throws SQLException { Statement s; if (fills.length > 0) { - s = globalState.getConnection().prepareStatement(fills[0]); + s = connection.prepareStatement(fills[0]); for (int i = 1; i < fills.length; i++) { ((PreparedStatement) s).setString(i, fills[i]); } } else { - s = globalState.getConnection().createStatement(); + s = connection.createStatement(); } ResultSet result; try { @@ -136,9 +203,11 @@ public void checkException(Exception e) throws AssertionError { } catch (Exception e) { s.close(); Main.nrUnsuccessfulActions.addAndGet(1); - checkException(e); + if (reportException) { + checkException(e); + } + return null; } - return null; } @Override diff --git a/src/sqlancer/common/query/SQLQueryError.java b/src/sqlancer/common/query/SQLQueryError.java new file mode 100644 index 000000000..41f604930 --- /dev/null +++ b/src/sqlancer/common/query/SQLQueryError.java @@ -0,0 +1,116 @@ +package sqlancer.common.query; + +import java.util.Objects; + +public class SQLQueryError implements Comparable { + + public enum ErrorLevel { + WARNING, ERROR + } + + private ErrorLevel level; + private int code; + private String message; + + public void setLevel(ErrorLevel level) { + this.level = level; + } + + public void setCode(int code) { + this.code = code; + } + + public void setMessage(String message) { + this.message = message; + } + + public ErrorLevel getLevel() { + return level; + } + + public int getCode() { + return code; + } + + public String getMessage() { + return message; + } + + public boolean hasSameLevel(SQLQueryError that) { + if (level == null) { + return that.getLevel() == null; + } else { + return level.equals(that.getLevel()); + } + } + + public boolean hasSameCodeAndMessage(SQLQueryError that) { + if (code != that.getCode()) { + return false; + } + if (message == null) { + return that.getMessage() == null; + } else { + return message.equals(that.getMessage()); + } + } + + @Override + public boolean equals(Object that) { + if (that == null) { + return false; + } + if (that instanceof SQLQueryError) { + SQLQueryError thatError = (SQLQueryError) that; + return hasSameLevel(thatError) && hasSameCodeAndMessage(thatError); + } + return false; + } + + @Override + public int hashCode() { + return Objects.hash(level, code, message); + } + + @Override + public String toString() { + return String.format("Level: %s; Code: %d; Message: %s.", level, code, message); + } + + @Override + public int compareTo(SQLQueryError that) { + if (code < that.getCode()) { + return -1; + } else if (code > that.getCode()) { + return 1; + } + + if (level == null && that.getLevel() != null) { + return -1; + } else { + if (that.getLevel() == null) { + return 1; + } else { + int res = level.compareTo(that.getLevel()); + if (res != 0) { + return res; + } + } + } + + if (message == null && that.getMessage() != null) { + return -1; + } else { + if (that.getMessage() == null) { + return 1; + } else { + int res = message.compareTo(that.getMessage()); + if (res != 0) { + return res; + } + } + } + + return 0; + } +} diff --git a/src/sqlancer/common/query/SQLQueryResultCheckAdapter.java b/src/sqlancer/common/query/SQLQueryResultCheckAdapter.java index 8f6a01a4b..1eaae2424 100644 --- a/src/sqlancer/common/query/SQLQueryResultCheckAdapter.java +++ b/src/sqlancer/common/query/SQLQueryResultCheckAdapter.java @@ -9,6 +9,7 @@ import sqlancer.SQLConnection; public class SQLQueryResultCheckAdapter extends SQLQueryAdapter { + private static final long serialVersionUID = 1L; private final Consumer rsChecker; diff --git a/src/sqlancer/common/query/SQLancerResultSet.java b/src/sqlancer/common/query/SQLancerResultSet.java index f83cdd3a5..912a6a882 100644 --- a/src/sqlancer/common/query/SQLancerResultSet.java +++ b/src/sqlancer/common/query/SQLancerResultSet.java @@ -35,7 +35,19 @@ public int getInt(int i) throws SQLException { } public String getString(int i) throws SQLException { - return rs.getString(i); + try { + return rs.getString(i); + } catch (NumberFormatException e) { + throw new SQLException(e); + } + } + + public String getString(String colName) throws SQLException { + return rs.getString(colName); + } + + public int getInt(String colName) throws SQLException { + return rs.getInt(colName); } public boolean isClosed() throws SQLException { @@ -46,6 +58,10 @@ public long getLong(int i) throws SQLException { return rs.getLong(i); } + public String getType(int i) throws SQLException { + return rs.getMetaData().getColumnTypeName(i); + } + public void registerEpilogue(Runnable runnableEpilogue) { this.runnableEpilogue = runnableEpilogue; } diff --git a/src/sqlancer/common/schema/AbstractRowValue.java b/src/sqlancer/common/schema/AbstractRowValue.java index c7f933c9b..2a979514d 100644 --- a/src/sqlancer/common/schema/AbstractRowValue.java +++ b/src/sqlancer/common/schema/AbstractRowValue.java @@ -63,7 +63,7 @@ public String asStringGroupedByTables() { sb.append("\n"); } AbstractTable t = tableList.get(j); - sb.append("-- " + t.getName() + "\n"); + sb.append("-- ").append(t.getName()).append("\n"); List columnsForTable = columnList.stream().filter(c -> c.getTable().equals(t)) .collect(Collectors.toList()); for (int i = 0; i < columnsForTable.size(); i++) { diff --git a/src/sqlancer/common/schema/AbstractSchema.java b/src/sqlancer/common/schema/AbstractSchema.java index 84a854f72..053c9540a 100644 --- a/src/sqlancer/common/schema/AbstractSchema.java +++ b/src/sqlancer/common/schema/AbstractSchema.java @@ -45,7 +45,7 @@ public A getRandomTable(Predicate predicate) { } public A getRandomTableOrBailout(Function f) { - List relevantTables = databaseTables.stream().filter(t -> f.apply(t)).collect(Collectors.toList()); + List relevantTables = databaseTables.stream().filter(f::apply).collect(Collectors.toList()); if (relevantTables.isEmpty()) { throw new IgnoreMeException(); } @@ -126,6 +126,10 @@ public String getFreeTableName() { } + public static boolean matchesViewName(String relationName) { + return relationName.startsWith("v"); + } + public String getFreeViewName() { int i = 0; if (Randomly.getBooleanWithRatherLowProbability()) { diff --git a/src/sqlancer/common/schema/AbstractTable.java b/src/sqlancer/common/schema/AbstractTable.java index c40f1a197..58154681c 100644 --- a/src/sqlancer/common/schema/AbstractTable.java +++ b/src/sqlancer/common/schema/AbstractTable.java @@ -2,9 +2,11 @@ import java.util.Collections; import java.util.List; +import java.util.function.Predicate; import java.util.stream.Collectors; import sqlancer.GlobalState; +import sqlancer.IgnoreMeException; import sqlancer.Randomly; public abstract class AbstractTable, I extends TableIndex, G extends GlobalState> @@ -17,7 +19,7 @@ public abstract class AbstractTable, I exten private final boolean isView; protected long rowCount = NO_ROW_COUNT_AVAILABLE; - public AbstractTable(String name, List columns, List indexes, boolean isView) { + protected AbstractTable(String name, List columns, List indexes, boolean isView) { this.name = name; this.indexes = indexes; this.isView = isView; @@ -39,7 +41,7 @@ public String toString() { sb.append(getName()); sb.append("\n"); for (C c : columns) { - sb.append("\t" + c + "\n"); + sb.append("\t").append(c).append("\n"); } return sb.toString(); } @@ -60,6 +62,15 @@ public C getRandomColumn() { return Randomly.fromList(columns); } + public C getRandomColumnOrBailout(Predicate predicate) { + List relevantColumns = columns.stream().filter(predicate).collect(Collectors.toList()); + if (relevantColumns.isEmpty()) { + throw new IgnoreMeException(); + } + + return Randomly.fromList(relevantColumns); + } + public boolean hasIndexes() { return !indexes.isEmpty(); } @@ -72,6 +83,10 @@ public List getRandomNonEmptyColumnSubset() { return Randomly.nonEmptySubset(getColumns()); } + public List getRandomNonEmptyColumnSubsetFilter(Predicate predicate) { + return Randomly.nonEmptySubset(getColumns().stream().filter(predicate).collect(Collectors.toList())); + } + public List getRandomNonEmptyColumnSubset(int size) { return Randomly.nonEmptySubset(getColumns(), size); } @@ -80,6 +95,10 @@ public boolean isView() { return isView; } + public boolean hasPrimaryKey() { + return columns.stream().anyMatch(c -> c.isPrimaryKey()); + } + public String getFreeColumnName() { int i = 0; if (Randomly.getBooleanWithRatherLowProbability()) { diff --git a/src/sqlancer/common/schema/AbstractTableColumn.java b/src/sqlancer/common/schema/AbstractTableColumn.java index e519bea97..a2f5fb1b3 100644 --- a/src/sqlancer/common/schema/AbstractTableColumn.java +++ b/src/sqlancer/common/schema/AbstractTableColumn.java @@ -12,6 +12,10 @@ public AbstractTableColumn(String name, T table, U type) { this.type = type; } + public boolean isPrimaryKey() { + return false; + } + public String getName() { return name; } diff --git a/src/sqlancer/common/schema/AbstractTables.java b/src/sqlancer/common/schema/AbstractTables.java index 2afff82ab..ff20fde20 100644 --- a/src/sqlancer/common/schema/AbstractTables.java +++ b/src/sqlancer/common/schema/AbstractTables.java @@ -34,4 +34,27 @@ public String columnNamesAsString(Function function) { return getColumns().stream().map(function).collect(Collectors.joining(", ")); } + public void addTable(T table) { + if (!this.tables.contains(table)) { + this.tables.add(table); + columns.addAll(table.getColumns()); + } + } + + public void removeTable(T table) { + if (this.tables.contains(table)) { + this.tables.remove(table); + for (C c : table.getColumns()) { + columns.remove(c); + } + } + } + + public boolean isContained(T table) { + return this.tables.contains(table); + } + + public int getSize() { + return this.tables.size(); + } } diff --git a/src/sqlancer/cosmos/CosmosProvider.java b/src/sqlancer/cosmos/CosmosProvider.java deleted file mode 100644 index 424625b4f..000000000 --- a/src/sqlancer/cosmos/CosmosProvider.java +++ /dev/null @@ -1,77 +0,0 @@ -package sqlancer.cosmos; - -import com.google.auto.service.AutoService; -import com.mongodb.ConnectionString; -import com.mongodb.MongoClientSettings; -import com.mongodb.client.MongoClient; -import com.mongodb.client.MongoClients; -import com.mongodb.client.MongoDatabase; - -import sqlancer.DatabaseProvider; -import sqlancer.IgnoreMeException; -import sqlancer.ProviderAdapter; -import sqlancer.Randomly; -import sqlancer.StatementExecutor; -import sqlancer.common.log.LoggableFactory; -import sqlancer.mongodb.MongoDBConnection; -import sqlancer.mongodb.MongoDBLoggableFactory; -import sqlancer.mongodb.MongoDBOptions; -import sqlancer.mongodb.MongoDBQueryAdapter; -import sqlancer.mongodb.gen.MongoDBTableGenerator; - -@AutoService(DatabaseProvider.class) -public class CosmosProvider extends - ProviderAdapter { - - public CosmosProvider() { - super(sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState.class, MongoDBOptions.class); - } - - @Override - public void generateDatabase(sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState globalState) throws Exception { - for (int i = 0; i < Randomly.fromOptions(4, 5, 6); i++) { - boolean success; - do { - MongoDBQueryAdapter query = new MongoDBTableGenerator(globalState).getQuery(globalState); - success = globalState.executeStatement(query); - } while (!success); - } - StatementExecutor se = new StatementExecutor<>( - globalState, sqlancer.mongodb.MongoDBProvider.Action.values(), - sqlancer.mongodb.MongoDBProvider::mapActions, (q) -> { - if (globalState.getSchema().getDatabaseTables().isEmpty()) { - throw new IgnoreMeException(); - } - }); - se.executeStatements(); - } - - @Override - public MongoDBConnection createDatabase(sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState globalState) - throws Exception { - String connectionString = ""; - if (connectionString.equals("")) { - throw new AssertionError("Please set connection string for cosmos database, located in CosmosProvider"); - } - MongoClientSettings settings = MongoClientSettings.builder() - .applyConnectionString(new ConnectionString(connectionString)).build(); - MongoClient mongoClient = MongoClients.create(settings); - MongoDatabase database = mongoClient.getDatabase(globalState.getDatabaseName()); - database.drop(); - return new MongoDBConnection(mongoClient, database); - } - - @Override - public String getDBMSName() { - return "cosmos"; - } - - @Override - public LoggableFactory getLoggableFactory() { - return new MongoDBLoggableFactory(); - } - - @Override - protected void checkViewsAreValid(sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState globalState) { - } -} diff --git a/src/sqlancer/databend/DatabendBugs.java b/src/sqlancer/databend/DatabendBugs.java new file mode 100644 index 000000000..a058ccc32 --- /dev/null +++ b/src/sqlancer/databend/DatabendBugs.java @@ -0,0 +1,26 @@ +package sqlancer.databend; + +public final class DatabendBugs { + + public static boolean bug9018; // https://github.com/datafuselabs/databend/issues/9018 + public static boolean bug9162; // https://github.com/datafuselabs/databend/issues/9162 + public static boolean bug9163; // https://github.com/datafuselabs/databend/issues/9163 + public static boolean bug9164 = true; // https://github.com/datafuselabs/databend/issues/9164 + public static boolean bug9196 = true; // https://github.com/datafuselabs/databend/issues/9196 + public static boolean bug9232 = true; // https://github.com/datafuselabs/databend/issues/9232 + public static boolean bug9224 = true; // https://github.com/datafuselabs/databend/issues/9224 + public static boolean bug9226 = true; // https://github.com/datafuselabs/databend/issues/9226 + public static boolean bug9234 = true; // https://github.com/datafuselabs/databend/issues/9234 + public static boolean bug9235 = true; // https://github.com/datafuselabs/databend/issues/9235 + public static boolean bug9236 = true; // https://github.com/datafuselabs/databend/issues/9236 + public static boolean bug9264 = true; // https://github.com/datafuselabs/databend/issues/9264 + public static boolean bug9806 = true; // https://github.com/datafuselabs/databend/issues/9806 + public static boolean bug15568 = true; // https://github.com/datafuselabs/databend/issues/15568 + public static boolean bug15569 = true; // https://github.com/datafuselabs/databend/issues/15569 + public static boolean bug15570 = true; // https://github.com/datafuselabs/databend/issues/15570 + public static boolean bug15572 = true; // https://github.com/datafuselabs/databend/issues/15572 + public static boolean bug19738 = true; // https://github.com/databendlabs/databend/issues/19738 + + private DatabendBugs() { + } +} diff --git a/src/sqlancer/databend/DatabendErrors.java b/src/sqlancer/databend/DatabendErrors.java index 60ea7ca9c..746a4e848 100644 --- a/src/sqlancer/databend/DatabendErrors.java +++ b/src/sqlancer/databend/DatabendErrors.java @@ -1,5 +1,8 @@ package sqlancer.databend; +import java.util.ArrayList; +import java.util.List; + import sqlancer.common.query.ExpectedErrors; public final class DatabendErrors { @@ -7,110 +10,92 @@ public final class DatabendErrors { private DatabendErrors() { } - public static void addExpressionErrors(ExpectedErrors errors) { - errors.add("with non-constant precision is not supported"); - errors.add("Like pattern must not end with escape character"); - errors.add("Could not convert string"); - errors.add("ORDER term out of range - should be between "); - errors.add("You might need to add explicit type casts."); - errors.add("can't be cast because the value is out of range for the destination type"); - errors.add("field value out of range"); - errors.add("Unimplemented type for cast"); - - errors.add("Escape string must be empty or one character."); - errors.add("Type mismatch when combining rows"); // BETWEEN - - errors.add("invalid UTF-8"); // TODO - errors.add("String value is not valid UTF8"); - - errors.add("Invalid TypeId "); // TODO - - errors.add("GROUP BY clause cannot contain aggregates!"); // investigate - - addRegexErrors(errors); - - addFunctionErrors(errors); - - errors.add("Overflow in multiplication"); - errors.add("Out of Range"); - errors.add("Date out of range"); + public static List getExpressionErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("Division by zero"); + errors.add("divided by zero"); + errors.add("/ by zero"); + errors.add("ORDER BY position"); + errors.add("GROUP BY position"); + errors.add("no overload satisfies `not(Float64 NULL)`"); // TODO databend不允许出现not(float),而a/b为float + errors.add("no overload satisfies `not(Float64)`"); + errors.add("number overflowed while evaluating function"); // 表达式数值溢出 + errors.add("Unable to get field named"); + errors.add("no overload satisfies `and_filters"); + if (DatabendBugs.bug9162) { + errors.add("downcast column error"); + } + if (DatabendBugs.bug9018) { + errors.add("index out of bounds"); + } + if (DatabendBugs.bug9163) { + errors.add("validity must be equal to the array's length"); + } + if (DatabendBugs.bug9224) { + errors.add("Can't cast column from nullable data into non-nullable type"); + } + if (DatabendBugs.bug9234) { + errors.add("called `Option::unwrap()` on a `None` value"); + } + if (DatabendBugs.bug9264) { + errors.add("assertion failed: offset + length <= self.length"); + } + if (DatabendBugs.bug9806) { + errors.add("segment pruning failure"); + } + if (DatabendBugs.bug15568) { + errors.add("Decimal overflow at line : 723 while evaluating function `to_decimal"); + } + if (DatabendBugs.bug19738) { + errors.add("UnwindError"); + errors.add("unable to cast `NULL`"); + } + + /* + * TODO column为not null 时,注意default不能为null DROP DATABASE IF EXISTS databend2; CREATE DATABASE databend2; USE + * databend2; CREATE TABLE t0(c0VARCHAR VARCHAR NULL, c1VARCHAR VARCHAR NULL, c2FLOAT FLOAT NOT NULL + * DEFAULT(NULL)); CREATE TABLE t1(c0INT BIGINT NULL); INSERT INTO t0(c1varchar, c0varchar) VALUES + * ('067596','19'), ('', '87'); + */ + errors.add("Can't cast column from null into non-nullable type"); + + return errors; + } - // collate - errors.add("Cannot combine types with different collation!"); - errors.add("collations are only supported for type varchar"); + public static void addExpressionErrors(ExpectedErrors errors) { + errors.addAll(getExpressionErrors()); + } - // // https://github.com/cwida/Databend/issues/532 - errors.add("Not implemented type: DATE"); - errors.add("Not implemented type: TIMESTAMP"); - errors.add("Like pattern must not end with escape character!"); // LIKE + public static List getInsertErrors() { + ArrayList errors = new ArrayList<>(); - errors.add("does not have a column named \"rowid\""); // TODO: this can be removed if we can query whether a - // table supports rowids + errors.add("Division by zero"); + errors.add("/ by zero"); + errors.add("violates not-null constraint"); + errors.add("number overflowed while evaluating function `"); // 不能在int16类型column上插入int64的数据 - errors.add("does not have a column named"); // TODO: this only happens for views whose underlying table has a - // removed column - errors.add("Contents of view were altered: types don't match!"); - errors.add("Not implemented: ROUND(DECIMAL, INTEGER) with non-constant precision is not supported"); - // TODO Databend待修复的bug(union schema error mismatch) - errors.add("unexpected end of file (failed to fill whole buffer)"); + return errors; } - private static void addRegexErrors(ExpectedErrors errors) { - errors.add("missing ]"); - errors.add("missing )"); - errors.add("invalid escape sequence"); - errors.add("no argument for repetition operator: "); - errors.add("bad repetition operator"); - errors.add("trailing \\"); - errors.add("invalid perl operator"); - errors.add("invalid character class range"); - errors.add("width is not integer"); + public static void addInsertErrors(ExpectedErrors errors) { + errors.addAll(getInsertErrors()); } - private static void addFunctionErrors(ExpectedErrors errors) { - errors.add("SUBSTRING cannot handle negative lengths"); - errors.add("is undefined outside [-1,1]"); // ACOS etc - errors.add("invalid type specifier"); // PRINTF - errors.add("argument index out of range"); // PRINTF - errors.add("invalid format string"); // PRINTF - errors.add("number is too big"); // PRINTF - errors.add("Like pattern must not end with escape character!"); // LIKE - errors.add("Could not choose a best candidate function for the function call \"date_part"); // date_part - errors.add("extract specifier"); // date_part - errors.add("not recognized"); // date_part - errors.add("not supported"); // date_part - errors.add("Failed to cast"); - errors.add("Conversion Error"); - errors.add("Could not cast value"); - errors.add("Insufficient padding in RPAD"); // RPAD - errors.add("Could not choose a best candidate function for the function call"); // monthname - errors.add("expected a numeric precision field"); // ROUND - errors.add("with non-constant precision is not supported"); // ROUND - } + public static List getGroupByErrors() { + ArrayList errors = new ArrayList<>(); - public static void addInsertErrors(ExpectedErrors errors) { - addRegexErrors(errors); - addFunctionErrors(errors); - - errors.add("NOT NULL constraint failed"); - errors.add("PRIMARY KEY or UNIQUE constraint violated"); - errors.add("duplicate key"); - errors.add("can't be cast because the value is out of range for the destination type"); - errors.add("Could not convert string"); - errors.add("Unimplemented type for cast"); - errors.add("field value out of range"); - errors.add("CHECK constraint failed"); - errors.add("Cannot explicitly insert values into rowid column"); // TODO: don't insert into rowid - errors.add(" Column with name rowid does not exist!"); // currently, there doesn't seem to way to determine if - // the table has a primary key - errors.add("Could not cast value"); - errors.add("create unique index, table contains duplicate data"); - errors.add("Failed to cast"); + errors.add("Division by zero"); + errors.add("/ by zero"); + errors.add("Can't cast column from null into non-nullable type"); + errors.add("GROUP BY position"); + errors.add("GROUP BY items can't contain aggregate functions or window functions"); + + return errors; } public static void addGroupByErrors(ExpectedErrors errors) { - errors.add("must appear in the GROUP BY clause or be used in an aggregate function"); - errors.add("GROUP BY term out of range"); + errors.addAll(getGroupByErrors()); } } diff --git a/src/sqlancer/databend/DatabendExpectedValueVisitor.java b/src/sqlancer/databend/DatabendExpectedValueVisitor.java new file mode 100644 index 000000000..bfa6208a3 --- /dev/null +++ b/src/sqlancer/databend/DatabendExpectedValueVisitor.java @@ -0,0 +1,152 @@ +package sqlancer.databend; + +import java.util.List; + +import sqlancer.databend.ast.DatabendAlias; +import sqlancer.databend.ast.DatabendBetweenOperation; +import sqlancer.databend.ast.DatabendBinaryOperation; +import sqlancer.databend.ast.DatabendColumnReference; +import sqlancer.databend.ast.DatabendConstant; +import sqlancer.databend.ast.DatabendExpression; +import sqlancer.databend.ast.DatabendFunctionOperation; +import sqlancer.databend.ast.DatabendInOperation; +import sqlancer.databend.ast.DatabendJoin; +import sqlancer.databend.ast.DatabendOrderByTerm; +import sqlancer.databend.ast.DatabendPostFixText; +import sqlancer.databend.ast.DatabendSelect; +import sqlancer.databend.ast.DatabendTableReference; +import sqlancer.databend.ast.DatabendUnaryPostfixOperation; +import sqlancer.databend.ast.DatabendUnaryPrefixOperation; + +public class DatabendExpectedValueVisitor { + + protected final StringBuilder sb = new StringBuilder(); + + private void print(DatabendExpression expr) { + sb.append(DatabendToStringVisitor.asString(expr)); + sb.append(" -- "); + sb.append(expr.getExpectedValue()); + sb.append("\n"); + } + + public void visit(DatabendExpression expr) { + assert expr != null; + if (expr instanceof DatabendColumnReference) { + visit((DatabendColumnReference) expr); + } else if (expr instanceof DatabendUnaryPostfixOperation) { + visit((DatabendUnaryPostfixOperation) expr); + } else if (expr instanceof DatabendUnaryPrefixOperation) { + visit((DatabendUnaryPrefixOperation) expr); + } else if (expr instanceof DatabendBinaryOperation) { + visit((DatabendBinaryOperation) expr); + } else if (expr instanceof DatabendTableReference) { + visit((DatabendTableReference) expr); + } else if (expr instanceof DatabendFunctionOperation) { + visit((DatabendFunctionOperation) expr); + } else if (expr instanceof DatabendBetweenOperation) { + visit((DatabendBetweenOperation) expr); + } else if (expr instanceof DatabendInOperation) { + visit((DatabendInOperation) expr); + } else if (expr instanceof DatabendOrderByTerm) { + visit((DatabendOrderByTerm) expr); + } else if (expr instanceof DatabendAlias) { + visit((DatabendAlias) expr); + } else if (expr instanceof DatabendPostFixText) { + visit((DatabendPostFixText) expr); + } else if (expr instanceof DatabendConstant) { + visit((DatabendConstant) expr); + } else if (expr instanceof DatabendSelect) { + visit((DatabendSelect) expr); + } else if (expr instanceof DatabendJoin) { + visit((DatabendJoin) expr); + } else { + throw new AssertionError(expr); + } + } + + public void visit(DatabendColumnReference c) { + print(c); + } + + public void visit(DatabendUnaryPostfixOperation op) { + print(op); + visit(op.getExpr()); + } + + public void visit(DatabendUnaryPrefixOperation op) { + print(op); + visit(op.getExpr()); + } + + public void visit(DatabendBinaryOperation op) { + print(op); + visit(op.getLeft()); + visit(op.getRight()); + } + + public void visit(DatabendTableReference t) { + print(t); + } + + public void visit(DatabendFunctionOperation fun) { + print(fun); + visit(fun.getArgs()); + } + + public void visit(List expressions) { + for (DatabendExpression expression : expressions) { + visit(expression); + } + } + + public void visit(DatabendBetweenOperation op) { + print(op); + visit(op.getLeft()); + visit(op.getMiddle()); + visit(op.getRight()); + } + + public void visit(DatabendInOperation op) { + print(op); + visit(op.getLeft()); + visit(op.getRight()); + } + + public void visit(DatabendOrderByTerm op) { + print(op); + visit(op.getExpr()); + } + + public void visit(DatabendAlias op) { + print(op); + visit(op.getExpr()); + } + + public void visit(DatabendPostFixText postFixText) { + print(postFixText); + visit(postFixText.getExpr()); + } + + public void visit(DatabendConstant constant) { + print(constant); + } + + public void visit(DatabendSelect select) { + print(select.getWhereClause()); + } + + public void visit(DatabendJoin join) { + print(join.getOnCondition()); + } + + public String get() { + return sb.toString(); + } + + public static String asExpectedValues(DatabendExpression expr) { + DatabendExpectedValueVisitor v = new DatabendExpectedValueVisitor(); + v.visit(expr); + return v.get(); + } + +} diff --git a/src/sqlancer/databend/DatabendOptions.java b/src/sqlancer/databend/DatabendOptions.java index ebacd21cf..d38ee0d59 100644 --- a/src/sqlancer/databend/DatabendOptions.java +++ b/src/sqlancer/databend/DatabendOptions.java @@ -1,7 +1,5 @@ package sqlancer.databend; -import java.sql.SQLException; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -9,17 +7,6 @@ import com.beust.jcommander.Parameters; import sqlancer.DBMSSpecificOptions; -import sqlancer.OracleFactory; -import sqlancer.common.oracle.CompositeTestOracle; -import sqlancer.common.oracle.TestOracle; -import sqlancer.databend.DatabendOptions.DatabendOracleFactory; -import sqlancer.databend.DatabendProvider.DatabendGlobalState; -import sqlancer.databend.test.DatabendNoRECOracle; -import sqlancer.databend.test.DatabendQueryPartitioningAggregateTester; -import sqlancer.databend.test.DatabendQueryPartitioningDistinctTester; -import sqlancer.databend.test.DatabendQueryPartitioningGroupByTester; -import sqlancer.databend.test.DatabendQueryPartitioningHavingTester; -import sqlancer.databend.test.DatabendQueryPartitioningWhereTester; @Parameters(commandDescription = "Databend") public class DatabendOptions implements DBMSSpecificOptions { @@ -95,62 +82,6 @@ public class DatabendOptions implements DBMSSpecificOptions oracles = Arrays.asList(DatabendOracleFactory.QUERY_PARTITIONING); - public enum DatabendOracleFactory implements OracleFactory { - NOREC { - - @Override - public TestOracle create(DatabendGlobalState globalState) throws SQLException { - return new DatabendNoRECOracle(globalState); - } - - }, - HAVING { - @Override - public TestOracle create(DatabendGlobalState globalState) throws SQLException { - return new DatabendQueryPartitioningHavingTester(globalState); - } - }, - WHERE { - @Override - public TestOracle create(DatabendGlobalState globalState) throws SQLException { - return new DatabendQueryPartitioningWhereTester(globalState); - } - }, - GROUP_BY { - @Override - public TestOracle create(DatabendGlobalState globalState) throws SQLException { - return new DatabendQueryPartitioningGroupByTester(globalState); - } - }, - AGGREGATE { - - @Override - public TestOracle create(DatabendGlobalState globalState) throws SQLException { - return new DatabendQueryPartitioningAggregateTester(globalState); - } - - }, - DISTINCT { - @Override - public TestOracle create(DatabendGlobalState globalState) throws SQLException { - return new DatabendQueryPartitioningDistinctTester(globalState); - } - }, - QUERY_PARTITIONING { - @Override - public TestOracle create(DatabendGlobalState globalState) throws SQLException { - List oracles = new ArrayList<>(); - oracles.add(new DatabendQueryPartitioningWhereTester(globalState)); - oracles.add(new DatabendQueryPartitioningHavingTester(globalState)); - oracles.add(new DatabendQueryPartitioningAggregateTester(globalState)); - oracles.add(new DatabendQueryPartitioningDistinctTester(globalState)); - oracles.add(new DatabendQueryPartitioningGroupByTester(globalState)); - return new CompositeTestOracle(oracles, globalState); - } - }; - - } - @Override public List getTestOracleFactory() { return oracles; diff --git a/src/sqlancer/databend/DatabendOracleFactory.java b/src/sqlancer/databend/DatabendOracleFactory.java new file mode 100644 index 000000000..ea66bc353 --- /dev/null +++ b/src/sqlancer/databend/DatabendOracleFactory.java @@ -0,0 +1,93 @@ +package sqlancer.databend; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import sqlancer.OracleFactory; +import sqlancer.common.oracle.CompositeTestOracle; +import sqlancer.common.oracle.NoRECOracle; +import sqlancer.common.oracle.TLPWhereOracle; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.databend.gen.DatabendNewExpressionGenerator; +import sqlancer.databend.test.DatabendPivotedQuerySynthesisOracle; +import sqlancer.databend.test.tlp.DatabendQueryPartitioningAggregateTester; +import sqlancer.databend.test.tlp.DatabendQueryPartitioningDistinctTester; +import sqlancer.databend.test.tlp.DatabendQueryPartitioningGroupByTester; +import sqlancer.databend.test.tlp.DatabendQueryPartitioningHavingTester; + +public enum DatabendOracleFactory implements OracleFactory { + NOREC { + @Override + public TestOracle create(DatabendProvider.DatabendGlobalState globalState) + throws SQLException { + DatabendNewExpressionGenerator gen = new DatabendNewExpressionGenerator(globalState); + ExpectedErrors errors = ExpectedErrors.newErrors().with(DatabendErrors.getExpressionErrors()) + .with("canceling statement due to statement timeout").build(); + return new NoRECOracle<>(globalState, gen, errors); + } + + }, + HAVING { + @Override + public TestOracle create(DatabendProvider.DatabendGlobalState globalState) + throws SQLException { + return new DatabendQueryPartitioningHavingTester(globalState); + } + }, + WHERE { + @Override + public TestOracle create(DatabendProvider.DatabendGlobalState globalState) + throws SQLException { + DatabendNewExpressionGenerator gen = new DatabendNewExpressionGenerator(globalState); + ExpectedErrors expectedErrors = ExpectedErrors.newErrors().with(DatabendErrors.getExpressionErrors()) + .with(DatabendErrors.getGroupByErrors()).build(); + + return new TLPWhereOracle<>(globalState, gen, expectedErrors); + } + }, + GROUP_BY { + @Override + public TestOracle create(DatabendProvider.DatabendGlobalState globalState) + throws SQLException { + return new DatabendQueryPartitioningGroupByTester(globalState); + } + }, + AGGREGATE { + @Override + public TestOracle create(DatabendProvider.DatabendGlobalState globalState) + throws SQLException { + return new DatabendQueryPartitioningAggregateTester(globalState); + } + + }, + DISTINCT { + @Override + public TestOracle create(DatabendProvider.DatabendGlobalState globalState) + throws SQLException { + return new DatabendQueryPartitioningDistinctTester(globalState); + } + }, + QUERY_PARTITIONING { + @Override + public TestOracle create(DatabendProvider.DatabendGlobalState globalState) + throws Exception { + List> oracles = new ArrayList<>(); + oracles.add(WHERE.create(globalState)); + oracles.add(HAVING.create(globalState)); + oracles.add(AGGREGATE.create(globalState)); + oracles.add(DISTINCT.create(globalState)); + oracles.add(GROUP_BY.create(globalState)); + return new CompositeTestOracle(oracles, globalState); + } + }, + PQS { + @Override + public TestOracle create(DatabendProvider.DatabendGlobalState globalState) + throws Exception { + return new DatabendPivotedQuerySynthesisOracle(globalState); + } + } + +} diff --git a/src/sqlancer/databend/DatabendProvider.java b/src/sqlancer/databend/DatabendProvider.java index 9a94241d0..df7802027 100644 --- a/src/sqlancer/databend/DatabendProvider.java +++ b/src/sqlancer/databend/DatabendProvider.java @@ -20,9 +20,11 @@ import sqlancer.common.query.SQLQueryAdapter; import sqlancer.common.query.SQLQueryProvider; import sqlancer.databend.DatabendProvider.DatabendGlobalState; +import sqlancer.databend.gen.DatabendDeleteGenerator; import sqlancer.databend.gen.DatabendInsertGenerator; import sqlancer.databend.gen.DatabendRandomQuerySynthesizer; import sqlancer.databend.gen.DatabendTableGenerator; +import sqlancer.databend.gen.DatabendViewGenerator; @AutoService(DatabaseProvider.class) public class DatabendProvider extends SQLProviderAdapter { @@ -33,13 +35,10 @@ public DatabendProvider() { public enum Action implements AbstractAction { - INSERT(DatabendInsertGenerator::getQuery), // - // TODO 等待databend实现update && delete - // DELETE(DatabendDeleteGenerator::generate), // + INSERT(DatabendInsertGenerator::getQuery), DELETE(DatabendDeleteGenerator::generate), + // TODO 等待databend实现update // UPDATE(DatabendUpdateGenerator::getQuery), // - - // CREATE_VIEW(DatabendViewGenerator::generate), //TODO 等待databend的create view语法 更加贴近mysql - EXPLAIN((g) -> { + CREATE_VIEW(DatabendViewGenerator::generate), EXPLAIN((g) -> { ExpectedErrors errors = new ExpectedErrors(); DatabendErrors.addExpressionErrors(errors); DatabendErrors.addGroupByErrors(errors); @@ -71,10 +70,10 @@ private static int mapActions(DatabendGlobalState globalState, Action a) { // TODO 等待databend实现update && delete // case UPDATE: // return r.getInteger(0, globalState.getDbmsSpecificOptions().maxNumUpdates + 1); - // case DELETE: - // return r.getInteger(0, globalState.getDbmsSpecificOptions().maxNumDeletes + 1); - // case CREATE_VIEW: //TODO 暂时关闭create view - // return r.getInteger(0, globalState.getDbmsSpecificOptions().maxNumViews + 1); + case DELETE: + return r.getInteger(0, globalState.getDbmsSpecificOptions().maxNumDeletes + 1); + case CREATE_VIEW: + return r.getInteger(0, globalState.getDbmsSpecificOptions().maxNumViews + 1); default: throw new AssertionError(a); } @@ -91,7 +90,7 @@ protected DatabendSchema readSchema() throws SQLException { @Override public void generateDatabase(DatabendGlobalState globalState) throws Exception { - for (int i = 0; i < Randomly.fromOptions(1, 2); i++) { + for (int i = 0; i < Randomly.fromOptions(3, 4); i++) { boolean success; do { SQLQueryAdapter qt = new DatabendTableGenerator().getQuery(globalState); @@ -107,7 +106,7 @@ public void generateDatabase(DatabendGlobalState globalState) throws Exception { throw new IgnoreMeException(); } }); - se.executeStatements(); // 在已有的表格中插入数据,原先是增删改一些数据,除了insert和explan我都去掉了 + se.executeStatements(); // 增删改一些数据(按权重随机选取算法) } @Override @@ -129,15 +128,18 @@ public SQLConnection createDatabase(DatabendGlobalState globalState) throws SQLE try (Statement s = con.createStatement()) { s.execute("DROP DATABASE IF EXISTS " + databaseName); globalState.getState().logStatement("DROP DATABASE IF EXISTS " + databaseName); - } - try (Statement s = con.createStatement()) { s.execute("CREATE DATABASE " + databaseName); globalState.getState().logStatement("CREATE DATABASE " + databaseName); - } - try (Statement s = con.createStatement()) { s.execute("USE " + databaseName); globalState.getState().logStatement("USE " + databaseName); } + if (DatabendBugs.bug15569) { + con.close(); + String urlWithRetry = String.format( + "jdbc:mysql://%s:%d/%s?serverTimezone=UTC&useSSL=false&allowPublicKeyRetrieval=true&autoReconnect=true", + host, port, databaseName); + con = DriverManager.getConnection(urlWithRetry, username, password); + } return new SQLConnection(con); } diff --git a/src/sqlancer/databend/DatabendSchema.java b/src/sqlancer/databend/DatabendSchema.java index 98d7acafd..89738a1f3 100644 --- a/src/sqlancer/databend/DatabendSchema.java +++ b/src/sqlancer/databend/DatabendSchema.java @@ -1,30 +1,35 @@ package sqlancer.databend; +import static sqlancer.databend.DatabendSchema.DatabendDataType.INT; + import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; import sqlancer.IgnoreMeException; import sqlancer.Randomly; import sqlancer.SQLConnection; -import sqlancer.common.DBMSCommon; import sqlancer.common.schema.AbstractRelationalTable; +import sqlancer.common.schema.AbstractRowValue; import sqlancer.common.schema.AbstractSchema; import sqlancer.common.schema.AbstractTableColumn; import sqlancer.common.schema.AbstractTables; import sqlancer.common.schema.TableIndex; import sqlancer.databend.DatabendProvider.DatabendGlobalState; import sqlancer.databend.DatabendSchema.DatabendTable; +import sqlancer.databend.ast.DatabendConstant; public class DatabendSchema extends AbstractSchema { public enum DatabendDataType { - INT, VARCHAR, BOOLEAN, FLOAT, NULL; - // , DATE, TIMESTAMP + INT, VARCHAR, BOOLEAN, FLOAT, NULL, DATE, TIMESTAMP; public static DatabendDataType getRandomWithoutNull() { DatabendDataType dt; @@ -78,8 +83,8 @@ public static DatabendCompositeDataType getRandomWithoutNull() { break; case BOOLEAN: case VARCHAR: - // case DATE: - // case TIMESTAMP: + case DATE: + case TIMESTAMP: size = 0; break; default: @@ -118,10 +123,10 @@ public String toString() { } case BOOLEAN: return Randomly.fromOptions("BOOLEAN", "BOOL"); - // case TIMESTAMP: - // return Randomly.fromOptions("TIMESTAMP", "DATETIME"); - // case DATE: - // return Randomly.fromOptions("DATE"); + case DATE: + return Randomly.fromOptions("DATE"); + case TIMESTAMP: + return Randomly.fromOptions("TIMESTAMP", "DATETIME"); case NULL: return Randomly.fromOptions("NULL"); default: @@ -143,6 +148,7 @@ public DatabendColumn(String name, DatabendCompositeDataType columnType, boolean this.isNullable = isNullable; } + @Override public boolean isPrimaryKey() { return isPrimaryKey; } @@ -159,6 +165,61 @@ public DatabendTables(List tables) { super(tables); } + public DatabendRowValue getRandomRowValue(SQLConnection con) throws SQLException { + String rowValueQuery = String.format("SELECT %s FROM %s ORDER BY 1 LIMIT 1", columnNamesAsString( + c -> c.getTable().getName() + "." + c.getName() + " AS " + c.getTable().getName() + c.getName()), + tableNamesAsString()); + Map values = new HashMap<>(); + try (Statement s = con.createStatement()) { + ResultSet rs = s.executeQuery(rowValueQuery); + if (!rs.next()) { + throw new IgnoreMeException(); + // throw new AssertionError("could not find random row " + rowValueQuery + "\n"); + } + for (int i = 0; i < getColumns().size(); i++) { + DatabendColumn column = getColumns().get(i); + int columnIndex = rs.findColumn(column.getTable().getName() + column.getName()); + assert columnIndex == i + 1; + DatabendConstant constant; + if (rs.getString(columnIndex) == null) { + constant = DatabendConstant.createNullConstant(); + } else { + switch (column.getType().getPrimitiveDataType()) { + case INT: + constant = DatabendConstant.createIntConstant(rs.getLong(columnIndex)); + break; + case BOOLEAN: + constant = DatabendConstant.createBooleanConstant(rs.getBoolean(columnIndex)); + break; + case VARCHAR: + constant = DatabendConstant.createStringConstant(rs.getString(columnIndex)); + break; + case DATE: + constant = DatabendConstant.createDateConstant(rs.getLong(columnIndex)); + break; + case TIMESTAMP: + constant = DatabendConstant.createTimestampConstant(rs.getLong(columnIndex)); + default: + throw new IgnoreMeException(); + } + } + values.put(column, constant); + } + assert !rs.next(); + return new DatabendRowValue(this, values); + } catch (SQLException e) { + throw new IgnoreMeException(); + } + } + + } + + public static class DatabendRowValue extends AbstractRowValue { + + DatabendRowValue(DatabendTables tables, Map values) { + super(tables, values); + } + } public DatabendSchema(List databaseTables) { @@ -169,49 +230,72 @@ public DatabendTables getRandomTableNonEmptyTables() { return new DatabendTables(Randomly.nonEmptySubset(getDatabaseTables())); } + public DatabendTables getRandomTableNonEmptyAndViewTables() { + List tables = getDatabaseTables().stream().filter(t -> !t.isView()).collect(Collectors.toList()); + tables = Randomly.nonEmptySubset(tables); + return new DatabendTables(tables); + } + private static DatabendCompositeDataType getColumnType(String typeString) { - DatabendDataType primitiveType; - int size = -1; if (typeString.startsWith("DECIMAL")) { // Ugly hack return new DatabendCompositeDataType(DatabendDataType.FLOAT, 8); } - switch (typeString) { - case "INT": - primitiveType = DatabendDataType.INT; - size = 4; + if (typeString.startsWith("Nullable")) { // Ugly hack + String substring = typeString.substring(typeString.indexOf('(') + 1, typeString.indexOf(')')); + return getColumnTypeNormalCases(substring); + } + return getColumnTypeNormalCases(typeString); + } + + private static DatabendCompositeDataType getColumnTypeNormalCases(String typeString) { + DatabendDataType primitiveType; + int size = -1; + switch (typeString.toUpperCase()) { + case "BOOLEAN": + case "BOOL": + primitiveType = DatabendDataType.BOOLEAN; + size = 1; + break; + case "TINYINT": + case "INT8": + primitiveType = INT; + size = 1; break; case "SMALLINT": - primitiveType = DatabendDataType.INT; + case "INT16": + primitiveType = INT; size = 2; break; + case "INT": + case "INT32": + primitiveType = INT; + size = 4; + break; case "BIGINT": - primitiveType = DatabendDataType.INT; + case "INT64": + primitiveType = INT; size = 8; break; - case "TINYINT": - primitiveType = DatabendDataType.INT; - size = 1; - break; - case "VARCHAR": - primitiveType = DatabendDataType.VARCHAR; - break; case "FLOAT": + case "FLOAT32": primitiveType = DatabendDataType.FLOAT; size = 4; break; case "DOUBLE": + case "FLOAT64": primitiveType = DatabendDataType.FLOAT; size = 8; break; - case "BOOLEAN": - primitiveType = DatabendDataType.BOOLEAN; + case "DATE": + primitiveType = DatabendDataType.DATE; + break; + case "TIMESTAMP": + primitiveType = DatabendDataType.TIMESTAMP; + break; + case "VARCHAR": + case "STRING": + primitiveType = DatabendDataType.VARCHAR; break; - // case "DATE": - // primitiveType = DatabendDataType.DATE; - // break; - // case "TIMESTAMP": - // primitiveType = DatabendDataType.TIMESTAMP; - // break; case "NULL": primitiveType = DatabendDataType.NULL; break; @@ -237,11 +321,8 @@ public static DatabendSchema fromConnection(SQLConnection con, String databaseNa List databaseTables = new ArrayList<>(); List tableNames = getTableNames(con, databaseName); for (String tableName : tableNames) { - if (DBMSCommon.matchesIndexName(tableName)) { - continue; // TODO: unexpected? - } List databaseColumns = getTableColumns(con, tableName, databaseName); - boolean isView = tableName.startsWith("v"); + boolean isView = matchesViewName(tableName); DatabendTable t = new DatabendTable(tableName, databaseColumns, isView); for (DatabendColumn c : databaseColumns) { c.setTable(t); @@ -280,6 +361,9 @@ private static List getTableColumns(SQLConnection con, String ta while (rs.next()) { String columnName = rs.getString("column_name"); String dataType = rs.getString("data_type"); + if (dataType.contains("NULL")) { + dataType = dataType.substring(0, dataType.indexOf(' ')); + } boolean isNullable = rs.getBoolean("is_nullable"); // boolean isPrimaryKey = rs.getString("pk").contains("true"); boolean isPrimaryKey = false; // 没找到主键元数据 @@ -292,11 +376,7 @@ private static List getTableColumns(SQLConnection con, String ta } } } - // if (columns.stream().noneMatch(c -> c.isPrimaryKey())) { - // TODO: implement an option to enable/disable rowids - // columns.add(new DatabendColumn("rowid", new DatabendCompositeDataType(DatabendDataType.INT, 4), false, - // false)); - // } + return columns; } diff --git a/src/sqlancer/databend/DatabendToStringVisitor.java b/src/sqlancer/databend/DatabendToStringVisitor.java index b3669e4b7..05d09d7f2 100644 --- a/src/sqlancer/databend/DatabendToStringVisitor.java +++ b/src/sqlancer/databend/DatabendToStringVisitor.java @@ -1,7 +1,6 @@ package sqlancer.databend; import sqlancer.common.ast.newast.NewToStringVisitor; -import sqlancer.common.ast.newast.Node; import sqlancer.databend.ast.DatabendConstant; import sqlancer.databend.ast.DatabendExpression; import sqlancer.databend.ast.DatabendJoin; @@ -10,7 +9,7 @@ public class DatabendToStringVisitor extends NewToStringVisitor { @Override - public void visitSpecific(Node expr) { + public void visitSpecific(DatabendExpression expr) { if (expr instanceof DatabendConstant) { visit((DatabendConstant) expr); } else if (expr instanceof DatabendSelect) { @@ -77,9 +76,9 @@ private void visit(DatabendSelect select) { sb.append(" HAVING "); visit(select.getHavingClause()); } - if (!select.getOrderByExpressions().isEmpty()) { + if (!select.getOrderByClauses().isEmpty()) { sb.append(" ORDER BY "); - visit(select.getOrderByExpressions()); + visit(select.getOrderByClauses()); } if (select.getLimitClause() != null) { sb.append(" LIMIT "); @@ -91,7 +90,7 @@ private void visit(DatabendSelect select) { } } - public static String asString(Node expr) { + public static String asString(DatabendExpression expr) { DatabendToStringVisitor visitor = new DatabendToStringVisitor(); visitor.visit(expr); return visitor.get(); diff --git a/src/sqlancer/databend/ast/DatabendAggregateOperation.java b/src/sqlancer/databend/ast/DatabendAggregateOperation.java new file mode 100644 index 000000000..5070991d8 --- /dev/null +++ b/src/sqlancer/databend/ast/DatabendAggregateOperation.java @@ -0,0 +1,46 @@ +package sqlancer.databend.ast; + +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.databend.DatabendSchema; + +public class DatabendAggregateOperation extends + DatabendFunctionOperation implements DatabendExpression { + public DatabendAggregateOperation(List args, DatabendAggregateFunction func) { + super(args, func); + } + + public enum DatabendAggregateFunction { + MAX(1), MIN(1), AVG(1, DatabendSchema.DatabendDataType.INT, DatabendSchema.DatabendDataType.FLOAT), COUNT(1), + SUM(1, DatabendSchema.DatabendDataType.INT, DatabendSchema.DatabendDataType.FLOAT), STDDEV_POP(1), COVAR_POP(1), + COVAR_SAMP(2); + // , *_IF, *_DISTINCT + + private int nrArgs; + private DatabendSchema.DatabendDataType[] dataTypes; + + DatabendAggregateFunction(int nrArgs, DatabendSchema.DatabendDataType... dataTypes) { + this.nrArgs = nrArgs; + this.dataTypes = dataTypes.clone(); + } + + public static DatabendAggregateFunction getRandom() { + return Randomly.fromOptions(values()); + } + + public DatabendSchema.DatabendDataType getRandomType() { + if (dataTypes.length == 0) { + return Randomly.fromOptions(DatabendSchema.DatabendDataType.values()); + } else { + return Randomly.fromOptions(dataTypes); + } + } + + public int getNrArgs() { + return nrArgs; + } + + } + +} diff --git a/src/sqlancer/databend/ast/DatabendAlias.java b/src/sqlancer/databend/ast/DatabendAlias.java new file mode 100644 index 000000000..5f6827409 --- /dev/null +++ b/src/sqlancer/databend/ast/DatabendAlias.java @@ -0,0 +1,9 @@ +package sqlancer.databend.ast; + +import sqlancer.common.ast.newast.NewAliasNode; + +public class DatabendAlias extends NewAliasNode implements DatabendExpression { + public DatabendAlias(DatabendExpression expr, String text) { + super(expr, text); + } +} diff --git a/src/sqlancer/databend/ast/DatabendBetweenOperation.java b/src/sqlancer/databend/ast/DatabendBetweenOperation.java new file mode 100644 index 000000000..a6a0134bf --- /dev/null +++ b/src/sqlancer/databend/ast/DatabendBetweenOperation.java @@ -0,0 +1,39 @@ +package sqlancer.databend.ast; + +import sqlancer.common.ast.newast.NewBetweenOperatorNode; +import sqlancer.databend.DatabendSchema; + +public class DatabendBetweenOperation extends NewBetweenOperatorNode implements DatabendExpression { + public DatabendBetweenOperation(DatabendExpression left, DatabendExpression middle, DatabendExpression right, + boolean isTrue) { + super(left, middle, right, isTrue); + } + + public DatabendExpression getLeftExpr() { + return left; + } + + public DatabendExpression getMiddleExpr() { + return middle; + } + + public DatabendExpression getRightExpr() { + return right; + } + + @Override + public DatabendConstant getExpectedValue() { + DatabendBinaryComparisonOperation leftComparison = new DatabendBinaryComparisonOperation(getMiddleExpr(), + getLeftExpr(), DatabendBinaryComparisonOperation.DatabendBinaryComparisonOperator.LESS_EQUALS); + DatabendBinaryComparisonOperation rightComparison = new DatabendBinaryComparisonOperation(getLeftExpr(), + getRightExpr(), DatabendBinaryComparisonOperation.DatabendBinaryComparisonOperator.LESS_EQUALS); + return new DatabendBinaryLogicalOperation(leftComparison, rightComparison, + DatabendBinaryLogicalOperation.DatabendBinaryLogicalOperator.AND).getExpectedValue(); + } + + @Override + public DatabendSchema.DatabendDataType getExpectedType() { + return DatabendSchema.DatabendDataType.BOOLEAN; + } + +} diff --git a/src/sqlancer/databend/ast/DatabendBinaryArithmeticOperation.java b/src/sqlancer/databend/ast/DatabendBinaryArithmeticOperation.java index dfc44ee22..5554cc474 100644 --- a/src/sqlancer/databend/ast/DatabendBinaryArithmeticOperation.java +++ b/src/sqlancer/databend/ast/DatabendBinaryArithmeticOperation.java @@ -1,18 +1,50 @@ package sqlancer.databend.ast; +import java.util.function.BinaryOperator; + import sqlancer.common.ast.BinaryOperatorNode; import sqlancer.common.ast.newast.NewBinaryOperatorNode; -import sqlancer.common.ast.newast.Node; +import sqlancer.databend.DatabendSchema.DatabendDataType; -public class DatabendBinaryArithmeticOperation extends NewBinaryOperatorNode { +public class DatabendBinaryArithmeticOperation extends NewBinaryOperatorNode + implements DatabendExpression { - public DatabendBinaryArithmeticOperation(Node left, Node right, + public DatabendBinaryArithmeticOperation(DatabendExpression left, DatabendExpression right, BinaryOperatorNode.Operator op) { super(left, right, op); } public enum DatabendBinaryArithmeticOperator implements BinaryOperatorNode.Operator { - ADDITION("+"), SUBTRACTION("-"), MULTIPLICATION("*"), DIVISION("/"), MODULO("%"); + ADDITION("+") { + @Override + public DatabendConstant apply(DatabendConstant left, DatabendConstant right) { + return applyOperation(left, right, (l, r) -> l + r); + } + }, + SUBTRACTION("-") { + @Override + public DatabendConstant apply(DatabendConstant left, DatabendConstant right) { + return applyOperation(left, right, (l, r) -> l - r); + } + }, + MULTIPLICATION("*") { + @Override + public DatabendConstant apply(DatabendConstant left, DatabendConstant right) { + return applyOperation(left, right, (l, r) -> l * r); + } + }, + DIVISION("/") { // TODO databend不允许出现not(float),而a/b为float + @Override + public DatabendConstant apply(DatabendConstant left, DatabendConstant right) { + return applyOperation(left, right, (l, r) -> r == 0 ? -1 : l / r); + } + }, + MODULO("%") { + @Override + public DatabendConstant apply(DatabendConstant left, DatabendConstant right) { + return applyOperation(left, right, (l, r) -> r == 0 ? -1 : l % r); + } + }; private final String textRepresentation; @@ -20,10 +52,49 @@ public enum DatabendBinaryArithmeticOperator implements BinaryOperatorNode.Opera textRepresentation = text; } + public abstract DatabendConstant apply(DatabendConstant left, DatabendConstant right); + + public DatabendConstant applyOperation(DatabendConstant left, DatabendConstant right, BinaryOperator op) { + if (left.isNull() || right.isNull()) { + return DatabendConstant.createNullConstant(); + } else { + long leftVal = left.cast(DatabendDataType.INT).asInt(); + long rightVal = right.cast(DatabendDataType.INT).asInt(); + return DatabendConstant.createIntConstant(op.apply(leftVal, rightVal)); + } + } + @Override public String getTextRepresentation() { return textRepresentation; } } + public DatabendExpression getLeftExpr() { + return super.getLeft(); + } + + public DatabendExpression getRightExpr() { + return super.getRight(); + } + + public DatabendBinaryArithmeticOperator getOp() { + return (DatabendBinaryArithmeticOperator) op; + } + + @Override + public DatabendConstant getExpectedValue() { + DatabendConstant leftValue = getLeftExpr().getExpectedValue(); + DatabendConstant rightValue = getRightExpr().getExpectedValue(); + if (leftValue == null || rightValue == null) { + return null; + } + return getOp().apply(leftValue, rightValue); + } + + @Override + public DatabendDataType getExpectedType() { + return DatabendDataType.INT; + } + } diff --git a/src/sqlancer/databend/ast/DatabendBinaryComparisonOperation.java b/src/sqlancer/databend/ast/DatabendBinaryComparisonOperation.java index cf63ad17e..12fc8c376 100644 --- a/src/sqlancer/databend/ast/DatabendBinaryComparisonOperation.java +++ b/src/sqlancer/databend/ast/DatabendBinaryComparisonOperation.java @@ -2,18 +2,125 @@ import sqlancer.common.ast.BinaryOperatorNode; import sqlancer.common.ast.newast.NewBinaryOperatorNode; -import sqlancer.common.ast.newast.Node; +import sqlancer.databend.DatabendSchema.DatabendDataType; -public class DatabendBinaryComparisonOperation extends NewBinaryOperatorNode { +public class DatabendBinaryComparisonOperation extends NewBinaryOperatorNode + implements DatabendExpression { - public DatabendBinaryComparisonOperation(Node left, Node right, + public DatabendBinaryComparisonOperation(DatabendExpression left, DatabendExpression right, DatabendBinaryComparisonOperator op) { super(left, right, op); } + public DatabendExpression getLeftExpression() { + return super.getLeft(); + } + + public DatabendExpression getRightExpression() { + return super.getRight(); + } + + public DatabendBinaryComparisonOperator getOp() { + return (DatabendBinaryComparisonOperator) op; + } + + @Override + public DatabendDataType getExpectedType() { + return DatabendDataType.BOOLEAN; + } + + @Override + public DatabendConstant getExpectedValue() { + DatabendConstant leftExpectedValue = getLeftExpression().getExpectedValue(); + DatabendConstant rightExpectedValue = getRightExpression().getExpectedValue(); + if (leftExpectedValue == null || rightExpectedValue == null) { + return null; + } + return getOp().apply(leftExpectedValue, rightExpectedValue); + } + public enum DatabendBinaryComparisonOperator implements BinaryOperatorNode.Operator { - EQUALS("="), IS_DISTINCT("IS DISTINCT FROM"), IS_NOT_DISTINCT("IS NOT DISTINCT FROM"), NOT_EQUALS("!="), - LESS("<"), LESS_EQUALS("<="), GREATER(">"), GREATER_EQUALS(">="); + EQUALS("=") { + @Override + public DatabendConstant apply(DatabendConstant left, DatabendConstant right) { + return left.isEquals(right); + } + }, + NOT_EQUALS("!=") { + @Override + public DatabendConstant apply(DatabendConstant left, DatabendConstant right) { + DatabendConstant isEquals = left.isEquals(right); + if (isEquals.isBoolean()) { + return DatabendConstant.createBooleanConstant(!isEquals.asBoolean()); + } + return isEquals; + } + }, + IS_DISTINCT("IS DISTINCT FROM") { + @Override + public DatabendConstant apply(DatabendConstant left, DatabendConstant right) { + return DatabendConstant.createBooleanConstant(!IS_NOT_DISTINCT.apply(left, right).asBoolean()); + } + }, + IS_NOT_DISTINCT("IS NOT DISTINCT FROM") { + @Override + public DatabendConstant apply(DatabendConstant left, DatabendConstant right) { + if (left.isNull()) { + return DatabendConstant.createBooleanConstant(right.isNull()); + } else if (right.isNull()) { + return DatabendConstant.createBooleanConstant(false); + } else { + return left.isEquals(right); + } + } + }, + LESS("<") { + @Override + public DatabendConstant apply(DatabendConstant left, DatabendConstant right) { + return left.isLessThan(right); + } + }, + LESS_EQUALS("<=") { + @Override + public DatabendConstant apply(DatabendConstant left, DatabendConstant right) { + DatabendConstant isLessThan = left.isLessThan(right); + if (isLessThan.isBoolean() && !isLessThan.asBoolean()) { + return left.isEquals(right); + } else { + return isLessThan; + } + } + }, + GREATER(">") { + @Override + public DatabendConstant apply(DatabendConstant left, DatabendConstant right) { + DatabendConstant isEquals = left.isEquals(right); + if (isEquals.isBoolean() && isEquals.asBoolean()) { + return DatabendConstant.createBooleanConstant(false); + } else { + DatabendConstant less = left.isLessThan(right); + if (less.isNull()) { + return DatabendConstant.createNullConstant(); + } + return DatabendConstant.createBooleanConstant(!less.asBoolean()); + } + } + }, + GREATER_EQUALS(">=") { + @Override + public DatabendConstant apply(DatabendConstant left, DatabendConstant right) { + DatabendConstant isEquals = left.isEquals(right); + if (isEquals.isBoolean() && isEquals.asBoolean()) { + return DatabendConstant.createBooleanConstant(true); + } else { + DatabendConstant less = left.isLessThan(right); + if (less.isNull()) { + return DatabendConstant.createNullConstant(); + } + return DatabendConstant.createBooleanConstant(!less.asBoolean()); + } + } + }; private final String textRepresentation; @@ -21,6 +128,8 @@ public enum DatabendBinaryComparisonOperator implements BinaryOperatorNode.Opera textRepresentation = text; } + public abstract DatabendConstant apply(DatabendConstant left, DatabendConstant right); + @Override public String getTextRepresentation() { return textRepresentation; diff --git a/src/sqlancer/databend/ast/DatabendBinaryLogicalOperation.java b/src/sqlancer/databend/ast/DatabendBinaryLogicalOperation.java index 355ec88a5..fd51e1eca 100644 --- a/src/sqlancer/databend/ast/DatabendBinaryLogicalOperation.java +++ b/src/sqlancer/databend/ast/DatabendBinaryLogicalOperation.java @@ -3,17 +3,91 @@ import sqlancer.Randomly; import sqlancer.common.ast.BinaryOperatorNode; import sqlancer.common.ast.newast.NewBinaryOperatorNode; -import sqlancer.common.ast.newast.Node; +import sqlancer.databend.DatabendSchema.DatabendDataType; -public class DatabendBinaryLogicalOperation extends NewBinaryOperatorNode { +public class DatabendBinaryLogicalOperation extends NewBinaryOperatorNode + implements DatabendExpression { - public DatabendBinaryLogicalOperation(Node left, Node right, + public DatabendBinaryLogicalOperation(DatabendExpression left, DatabendExpression right, DatabendBinaryLogicalOperator op) { super(left, right, op); } + public DatabendExpression getLeftExpr() { + return super.getLeft(); + } + + public DatabendExpression getRightExpr() { + return super.getRight(); + } + + public DatabendBinaryLogicalOperator getOp() { + return (DatabendBinaryLogicalOperator) op; + } + + @Override + public DatabendConstant getExpectedValue() { + DatabendConstant leftValue = getLeftExpr().getExpectedValue(); + DatabendConstant rightValue = getRightExpr().getExpectedValue(); + if (leftValue == null || rightValue == null) { + return null; + } + return getOp().apply(leftValue, rightValue); + } + + @Override + public DatabendDataType getExpectedType() { + return DatabendDataType.BOOLEAN; + } + public enum DatabendBinaryLogicalOperator implements BinaryOperatorNode.Operator { - AND("AND", "and"), OR("OR", "or"); + AND("AND", "and") { + @Override + public DatabendConstant apply(DatabendConstant left, DatabendConstant right) { + DatabendConstant leftVal = left.cast(DatabendDataType.BOOLEAN); + DatabendConstant rightVal = right.cast(DatabendDataType.BOOLEAN); + assert leftVal.isNull() || leftVal.isBoolean() : leftVal + "不是NULL也不是Boolean类型"; + assert rightVal.isNull() || rightVal.isBoolean() : rightVal + "不是NULL也不是Boolean类型"; + if (leftVal.isNull()) { + if (rightVal.isNull()) { + return DatabendConstant.createNullConstant(); + } else { + if (rightVal.asBoolean()) { + return DatabendConstant.createNullConstant(); + } else { + return DatabendConstant.createBooleanConstant(false); + } + } + } else if (!leftVal.asBoolean()) { + return DatabendConstant.createBooleanConstant(false); + } + assert leftVal.asBoolean(); + if (rightVal.isNull()) { + return DatabendConstant.createNullConstant(); + } else { + return DatabendConstant.createBooleanConstant(rightVal.asBoolean()); + } + } + }, + OR("OR", "or") { + @Override + public DatabendConstant apply(DatabendConstant left, DatabendConstant right) { + DatabendConstant leftVal = left.cast(DatabendDataType.BOOLEAN); + DatabendConstant rightVal = right.cast(DatabendDataType.BOOLEAN); + assert leftVal.isNull() || leftVal.isBoolean() : leftVal + "不是NULL也不是Boolean类型"; + assert rightVal.isNull() || rightVal.isBoolean() : rightVal + "不是NULL也不是Boolean类型"; + if (leftVal.isBoolean() && leftVal.asBoolean()) { + return DatabendConstant.createBooleanConstant(true); + } + if (rightVal.isBoolean() && rightVal.asBoolean()) { + return DatabendConstant.createBooleanConstant(true); + } + if (leftVal.isNull() || rightVal.isNull()) { + return DatabendConstant.createNullConstant(); + } + return DatabendConstant.createBooleanConstant(false); + } + }; private final String[] textRepresentations; @@ -33,6 +107,9 @@ public DatabendBinaryLogicalOperator getRandomOp() { public static DatabendBinaryLogicalOperator getRandom() { return Randomly.fromOptions(values()); } + + public abstract DatabendConstant apply(DatabendConstant left, DatabendConstant right); + } } diff --git a/src/sqlancer/databend/ast/DatabendBinaryOperation.java b/src/sqlancer/databend/ast/DatabendBinaryOperation.java new file mode 100644 index 000000000..c9fab806e --- /dev/null +++ b/src/sqlancer/databend/ast/DatabendBinaryOperation.java @@ -0,0 +1,12 @@ +package sqlancer.databend.ast; + +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.newast.NewBinaryOperatorNode; + +public class DatabendBinaryOperation extends NewBinaryOperatorNode implements DatabendExpression { + public DatabendBinaryOperation(DatabendExpression left, DatabendExpression right, + BinaryOperatorNode.Operator operator) { + super(left, right, operator); + } + +} diff --git a/src/sqlancer/databend/ast/DatabendCastOperation.java b/src/sqlancer/databend/ast/DatabendCastOperation.java index c1b3ffe66..07e7d7f9c 100644 --- a/src/sqlancer/databend/ast/DatabendCastOperation.java +++ b/src/sqlancer/databend/ast/DatabendCastOperation.java @@ -2,18 +2,39 @@ import sqlancer.common.ast.BinaryOperatorNode; import sqlancer.common.ast.newast.NewUnaryPostfixOperatorNode; -import sqlancer.common.ast.newast.Node; -import sqlancer.databend.DatabendSchema; +import sqlancer.databend.DatabendSchema.DatabendCompositeDataType; +import sqlancer.databend.DatabendSchema.DatabendDataType; -public class DatabendCastOperation extends NewUnaryPostfixOperatorNode { +public class DatabendCastOperation extends NewUnaryPostfixOperatorNode + implements DatabendExpression { - public DatabendCastOperation(Node expr, DatabendSchema.DatabendCompositeDataType type) { + DatabendDataType type; + + public DatabendCastOperation(DatabendExpression expr, DatabendCompositeDataType type) { super(expr, new BinaryOperatorNode.Operator() { @Override public String getTextRepresentation() { return "::" + type.toString(); } }); + this.type = type.getPrimitiveDataType(); + } + + DatabendExpression getExpression() { + return (DatabendExpression) getExpr(); } + @Override + public DatabendConstant getExpectedValue() { + DatabendConstant expectedValue = getExpression().getExpectedValue(); + if (expectedValue == null) { + return null; + } + return expectedValue.cast(type); + } + + @Override + public DatabendDataType getExpectedType() { + return type; + } } diff --git a/src/sqlancer/databend/ast/DatabendColumnReference.java b/src/sqlancer/databend/ast/DatabendColumnReference.java new file mode 100644 index 000000000..4b53ef4df --- /dev/null +++ b/src/sqlancer/databend/ast/DatabendColumnReference.java @@ -0,0 +1,11 @@ +package sqlancer.databend.ast; + +import sqlancer.common.ast.newast.ColumnReferenceNode; +import sqlancer.databend.DatabendSchema; + +public class DatabendColumnReference extends ColumnReferenceNode + implements DatabendExpression { + public DatabendColumnReference(DatabendSchema.DatabendColumn column) { + super(column); + } +} diff --git a/src/sqlancer/databend/ast/DatabendColumnValue.java b/src/sqlancer/databend/ast/DatabendColumnValue.java new file mode 100644 index 000000000..8b1992880 --- /dev/null +++ b/src/sqlancer/databend/ast/DatabendColumnValue.java @@ -0,0 +1,31 @@ +package sqlancer.databend.ast; + +import sqlancer.common.ast.newast.ColumnReferenceNode; +import sqlancer.databend.DatabendSchema.DatabendColumn; +import sqlancer.databend.DatabendSchema.DatabendDataType; + +public class DatabendColumnValue extends ColumnReferenceNode + implements DatabendExpression { + + private final DatabendConstant expectedValue; + + public DatabendColumnValue(DatabendColumn column, DatabendConstant value) { + super(column); + this.expectedValue = value; + } + + @Override + public DatabendConstant getExpectedValue() { + return expectedValue; + } + + @Override + public DatabendDataType getExpectedType() { + return getColumn().getType().getPrimitiveDataType(); + } + + public static DatabendColumnValue create(DatabendColumn column, DatabendConstant value) { + return new DatabendColumnValue(column, value); + } + +} diff --git a/src/sqlancer/databend/ast/DatabendConstant.java b/src/sqlancer/databend/ast/DatabendConstant.java index 39056aff9..919942a8a 100644 --- a/src/sqlancer/databend/ast/DatabendConstant.java +++ b/src/sqlancer/databend/ast/DatabendConstant.java @@ -3,13 +3,64 @@ import java.sql.Timestamp; import java.text.SimpleDateFormat; -import sqlancer.common.ast.newast.Node; +import sqlancer.databend.DatabendSchema.DatabendDataType; -public class DatabendConstant implements Node { +public abstract class DatabendConstant implements DatabendExpression { private DatabendConstant() { } + public boolean isNull() { + return false; + } + + public boolean isInt() { + return false; + } + + public boolean isBoolean() { + return false; + } + + public boolean isString() { + return false; + } + + public boolean isFloat() { + return false; + } + + public abstract DatabendConstant cast(DatabendDataType dataType); + + public boolean asBoolean() { + throw new UnsupportedOperationException(this.toString()); + } + + public long asInt() { + throw new UnsupportedOperationException(this.toString()); + } + + public String asString() { + throw new UnsupportedOperationException(this.toString()); + } + + public double asFloat() { + throw new UnsupportedOperationException(this.toString()); + } + + protected Timestamp truncateTimestamp(long val) { + // Databend supports `date` and `timestamp` type where the year cannot exceed `9999`, + // the value is truncated to ensure generate legitimate `date` and `timestamp` value. + long t = val % 253380000000000L; + return new Timestamp(t); + } + + public abstract DatabendConstant isEquals(DatabendConstant rightVal); + + public abstract DatabendConstant isLessThan(DatabendConstant rightVal); + + // public abstract String getTextRepresentation(); + public static class DatabendNullConstant extends DatabendConstant { @Override @@ -17,6 +68,35 @@ public String toString() { return "NULL"; } + @Override + public boolean isNull() { + return true; + } + + @Override + public DatabendConstant cast(DatabendDataType dataType) { + return DatabendConstant.createNullConstant(); + } + + @Override + public DatabendConstant isEquals(DatabendConstant rightVal) { + return DatabendConstant.createNullConstant(); + } + + @Override + public DatabendConstant isLessThan(DatabendConstant rightVal) { + return DatabendConstant.createNullConstant(); + } + + @Override + public DatabendDataType getExpectedType() { + return DatabendDataType.NULL; + } + + // @Override + // public DatabendConstant getExpectedValue() { + // return super.getExpectedValue(); + // } } public static class DatabendIntConstant extends DatabendConstant { @@ -36,13 +116,70 @@ public long getValue() { return value; } + @Override + public boolean isInt() { + return true; + } + + @Override + public DatabendConstant cast(DatabendDataType dataType) { + switch (dataType) { + case BOOLEAN: + return new DatabendBooleanConstant(value != 0); + case INT: + return this; + case VARCHAR: + return new DatabendStringConstant(String.valueOf(value)); + case DATE: + return new DatabendDateConstant(value); + case TIMESTAMP: + return new DatabendTimestampConstant(value); + default: + return null; + } + } + + @Override + public long asInt() { + return value; + } + + @Override + public DatabendConstant isEquals(DatabendConstant rightVal) { + if (rightVal.isNull()) { + return DatabendConstant.createNullConstant(); + } else if (rightVal.isInt()) { + return DatabendConstant.createBooleanConstant(value == rightVal.asInt()); + } else { + throw new AssertionError(rightVal); + } + + } + + @Override + public DatabendConstant isLessThan(DatabendConstant rightVal) { + if (rightVal.isNull()) { + return DatabendConstant.createNullConstant(); + } else if (rightVal.isInt()) { + return DatabendConstant.createBooleanConstant(value < rightVal.asInt()); + } else if (rightVal.isFloat()) { + return DatabendConstant.createBooleanConstant(value < rightVal.asFloat()); + } else { + throw new AssertionError(rightVal); + } + } + + @Override + public DatabendDataType getExpectedType() { + return DatabendDataType.INT; + } } - public static class DatabendDoubleConstant extends DatabendConstant { + public static class DatabendFloatConstant extends DatabendConstant { private final double value; - public DatabendDoubleConstant(double value) { + public DatabendFloatConstant(double value) { this.value = value; } @@ -50,6 +187,11 @@ public double getValue() { return value; } + @Override + public boolean isFloat() { + return true; + } + @Override public String toString() { if (value == Double.POSITIVE_INFINITY) { @@ -61,13 +203,51 @@ public String toString() { return String.valueOf(value); } + @Override + public DatabendConstant cast(DatabendDataType dataType) { + switch (dataType) { + case FLOAT: + return this; + case INT: + return DatabendConstant.createIntConstant((long) value); + case BOOLEAN: + return DatabendConstant.createBooleanConstant(value != 0); + case VARCHAR: + return DatabendConstant.createStringConstant(String.valueOf(value)); + default: + return null; + } + } + + @Override + public double asFloat() { + return value; + } + + @Override + public DatabendConstant isEquals(DatabendConstant rightVal) { + return null; + } + + @Override + public DatabendConstant isLessThan(DatabendConstant rightVal) { + if (rightVal.isNull()) { + return DatabendConstant.createNullConstant(); + } else if (rightVal.isInt()) { + return DatabendConstant.createBooleanConstant(value < rightVal.asInt()); + } else if (rightVal.isFloat()) { + return DatabendConstant.createBooleanConstant(value < rightVal.asFloat()); + } else { + throw new AssertionError(rightVal); + } + } } - public static class DatabendTextConstant extends DatabendConstant { + public static class DatabendStringConstant extends DatabendConstant { private final String value; - public DatabendTextConstant(String value) { + public DatabendStringConstant(String value) { this.value = value; } @@ -80,25 +260,68 @@ public String toString() { return "'" + value.replace("'", "''") + "'"; } - } - - public static class DatabendBitConstant extends DatabendConstant { - - private final String value; + @Override + public String asString() { + return value; + } - public DatabendBitConstant(long value) { - this.value = Long.toBinaryString(value); + @Override + public boolean isString() { + return true; } - public String getValue() { - return value; + @Override + public DatabendConstant cast(DatabendDataType dataType) { + switch (dataType) { + case VARCHAR: + return this; + case INT: + try { + return new DatabendIntConstant(Long.parseLong(value)); + } catch (NumberFormatException e) { + return new DatabendIntConstant(-1); + } + case BOOLEAN: + if ("false".contentEquals(value.toLowerCase())) { + return new DatabendBooleanConstant(false); + } else if ("true".contentEquals(value.toLowerCase())) { + return new DatabendBooleanConstant(true); + } else { + throw new AssertionError(String.format("string: %s, cannot be forced to boolean", value)); + } + case FLOAT: + try { + return new DatabendFloatConstant(Double.parseDouble(value)); + } catch (NumberFormatException e) { + return new DatabendFloatConstant(-1); + } + default: + return null; + } } @Override - public String toString() { - return "B'" + value + "'"; + public DatabendConstant isEquals(DatabendConstant rightVal) { + if (rightVal.isNull()) { + return DatabendConstant.createNullConstant(); + } else if (rightVal.isString()) { + return DatabendConstant.createBooleanConstant(value.contentEquals(rightVal.asString())); + } else { + // TODO 可以比较 date和timestamp类型,待添加 + throw new AssertionError(rightVal); + } } + @Override + public DatabendConstant isLessThan(DatabendConstant rightVal) { + if (rightVal.isNull()) { + return DatabendConstant.createNullConstant(); + } else if (rightVal.isString()) { + return DatabendConstant.createBooleanConstant(value.compareTo(rightVal.asString()) < 0); + } else { + throw new AssertionError(rightVal); + } + } } public static class DatabendDateConstant extends DatabendConstant { @@ -106,7 +329,7 @@ public static class DatabendDateConstant extends DatabendConstant { public String textRepr; public DatabendDateConstant(long val) { - Timestamp timestamp = new Timestamp(val); + Timestamp timestamp = truncateTimestamp(val); SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd"); textRepr = dateFormat.format(timestamp); } @@ -120,6 +343,20 @@ public String toString() { return String.format("DATE '%s'", textRepr); } + @Override + public DatabendConstant cast(DatabendDataType dataType) { + return null; + } + + @Override + public DatabendConstant isEquals(DatabendConstant rightVal) { + return null; + } + + @Override + public DatabendConstant isLessThan(DatabendConstant rightVal) { + return null; + } } public static class DatabendTimestampConstant extends DatabendConstant { @@ -127,7 +364,7 @@ public static class DatabendTimestampConstant extends DatabendConstant { public String textRepr; public DatabendTimestampConstant(long val) { - Timestamp timestamp = new Timestamp(val); + Timestamp timestamp = truncateTimestamp(val); SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); textRepr = dateFormat.format(timestamp); } @@ -141,6 +378,20 @@ public String toString() { return String.format("TIMESTAMP '%s'", textRepr); } + @Override + public DatabendConstant cast(DatabendDataType dataType) { + return null; + } + + @Override + public DatabendConstant isEquals(DatabendConstant rightVal) { + return null; + } + + @Override + public DatabendConstant isLessThan(DatabendConstant rightVal) { + return null; + } } public static class DatabendBooleanConstant extends DatabendConstant { @@ -160,33 +411,82 @@ public String toString() { return String.valueOf(value); } + @Override + public boolean asBoolean() { + return value; + } + + @Override + public boolean isBoolean() { + return true; + } + + @Override + public DatabendConstant cast(DatabendDataType dataType) { + switch (dataType) { + case BOOLEAN: + return this; + case INT: + return new DatabendIntConstant(value ? 1 : 0); + case FLOAT: + return new DatabendFloatConstant(value ? 1 : 0); + case VARCHAR: + return new DatabendStringConstant(value ? "1" : "0"); + default: + return null; + } + } + + @Override + public DatabendConstant isEquals(DatabendConstant rightVal) { + if (rightVal.isNull()) { + return DatabendConstant.createNullConstant(); + } else if (rightVal.isBoolean()) { + return DatabendConstant.createBooleanConstant(value == rightVal.asBoolean()); + } else { + throw new AssertionError(rightVal); + } + + } + + @Override + public DatabendConstant isLessThan(DatabendConstant rightVal) { + if (rightVal.isNull()) { + return DatabendConstant.createNullConstant(); + } else if (rightVal.isBoolean()) { + return DatabendConstant.createBooleanConstant((value ? 1 : 0) < (rightVal.asBoolean() ? 1 : 0)); + } else { + throw new AssertionError(rightVal); + } + } + } - public static Node createStringConstant(String text) { - return new DatabendTextConstant(text); + public static DatabendConstant createStringConstant(String text) { + return new DatabendStringConstant(text); } - public static Node createFloatConstant(double val) { - return new DatabendDoubleConstant(val); + public static DatabendConstant createFloatConstant(double val) { + return new DatabendFloatConstant(val); } - public static Node createIntConstant(long val) { + public static DatabendConstant createIntConstant(long val) { return new DatabendIntConstant(val); } - public static Node createNullConstant() { + public static DatabendConstant createNullConstant() { return new DatabendNullConstant(); } - public static Node createBooleanConstant(boolean val) { + public static DatabendConstant createBooleanConstant(boolean val) { return new DatabendBooleanConstant(val); } - public static Node createDateConstant(long integer) { + public static DatabendConstant createDateConstant(long integer) { return new DatabendDateConstant(integer); } - public static Node createTimestampConstant(long integer) { + public static DatabendConstant createTimestampConstant(long integer) { return new DatabendTimestampConstant(integer); } diff --git a/src/sqlancer/databend/ast/DatabendExpression.java b/src/sqlancer/databend/ast/DatabendExpression.java index 31b9bf859..d71e64ed7 100644 --- a/src/sqlancer/databend/ast/DatabendExpression.java +++ b/src/sqlancer/databend/ast/DatabendExpression.java @@ -1,5 +1,16 @@ package sqlancer.databend.ast; -public interface DatabendExpression { +import sqlancer.common.ast.newast.Expression; +import sqlancer.databend.DatabendSchema.DatabendColumn; +import sqlancer.databend.DatabendSchema.DatabendDataType; +public interface DatabendExpression extends Expression { + + default DatabendDataType getExpectedType() { + return null; + } + + default DatabendConstant getExpectedValue() { + return null; + } } diff --git a/src/sqlancer/databend/ast/DatabendFunctionOperation.java b/src/sqlancer/databend/ast/DatabendFunctionOperation.java new file mode 100644 index 000000000..69cb16175 --- /dev/null +++ b/src/sqlancer/databend/ast/DatabendFunctionOperation.java @@ -0,0 +1,11 @@ +package sqlancer.databend.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewFunctionNode; + +public class DatabendFunctionOperation extends NewFunctionNode implements DatabendExpression { + public DatabendFunctionOperation(List args, F func) { + super(args, func); + } +} diff --git a/src/sqlancer/databend/ast/DatabendInOperation.java b/src/sqlancer/databend/ast/DatabendInOperation.java new file mode 100644 index 000000000..f382aae04 --- /dev/null +++ b/src/sqlancer/databend/ast/DatabendInOperation.java @@ -0,0 +1,52 @@ +package sqlancer.databend.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewInOperatorNode; +import sqlancer.databend.DatabendSchema; + +public class DatabendInOperation extends NewInOperatorNode implements DatabendExpression { + + private final DatabendExpression leftExpr; + private final List rightExpr; + + public DatabendInOperation(DatabendExpression left, List right, boolean isNegated) { + super(left, right, isNegated); + this.leftExpr = left; + this.rightExpr = right; + } + + @Override + public DatabendSchema.DatabendDataType getExpectedType() { + return DatabendSchema.DatabendDataType.BOOLEAN; + } + + @Override + public DatabendConstant getExpectedValue() { + DatabendConstant leftValue = leftExpr.getExpectedValue(); + if (leftValue == null) { + return null; + } + if (leftValue.isNull()) { + return DatabendConstant.createNullConstant(); + } + boolean isNull = false; + for (DatabendExpression expr : rightExpr) { + DatabendConstant rightValue = expr.getExpectedValue(); + if (rightValue == null) { + return null; + } + if (rightValue.isNull()) { + isNull = true; + } else if (rightValue.isEquals(leftValue).isBoolean() && rightValue.isEquals(leftValue).asBoolean()) { + return DatabendConstant.createBooleanConstant(!isNegated()); + } + } + + if (isNull) { + return DatabendConstant.createNullConstant(); + } else { + return DatabendConstant.createBooleanConstant(isNegated()); + } + } +} diff --git a/src/sqlancer/databend/ast/DatabendJoin.java b/src/sqlancer/databend/ast/DatabendJoin.java index 899929873..9386c1fe1 100644 --- a/src/sqlancer/databend/ast/DatabendJoin.java +++ b/src/sqlancer/databend/ast/DatabendJoin.java @@ -4,7 +4,7 @@ import java.util.List; import sqlancer.Randomly; -import sqlancer.common.ast.newast.Node; +import sqlancer.common.ast.newast.Join; import sqlancer.common.ast.newast.TableReferenceNode; import sqlancer.databend.DatabendProvider.DatabendGlobalState; import sqlancer.databend.DatabendSchema; @@ -12,12 +12,12 @@ import sqlancer.databend.DatabendSchema.DatabendTable; import sqlancer.databend.gen.DatabendNewExpressionGenerator; -public class DatabendJoin implements Node { +public class DatabendJoin implements DatabendExpression, Join { - private final TableReferenceNode leftTable; - private final TableReferenceNode rightTable; + private final DatabendTableReference leftTable; + private final DatabendTableReference rightTable; private final JoinType joinType; - private final Node onCondition; + private DatabendExpression onCondition; private OuterType outerType; public enum JoinType { @@ -36,9 +36,8 @@ public static OuterType getRandom() { } } - public DatabendJoin(TableReferenceNode leftTable, - TableReferenceNode rightTable, JoinType joinType, - Node whereCondition) { + public DatabendJoin(DatabendTableReference leftTable, DatabendTableReference rightTable, JoinType joinType, + DatabendExpression whereCondition) { this.leftTable = leftTable; this.rightTable = rightTable; this.joinType = joinType; @@ -57,7 +56,7 @@ public JoinType getJoinType() { return joinType; } - public Node getOnCondition() { + public DatabendExpression getOnCondition() { return onCondition; } @@ -69,12 +68,11 @@ public OuterType getOuterType() { return outerType; } - public static List> getJoins( - List> tableList, DatabendGlobalState globalState) { - List> joinExpressions = new ArrayList<>(); + public static List getJoins(List tableList, DatabendGlobalState globalState) { + List joinExpressions = new ArrayList<>(); while (tableList.size() >= 2 && Randomly.getBooleanWithRatherLowProbability()) { - TableReferenceNode leftTable = tableList.remove(0); - TableReferenceNode rightTable = tableList.remove(0); + DatabendTableReference leftTable = tableList.remove(0); + DatabendTableReference rightTable = tableList.remove(0); List columns = new ArrayList<>(leftTable.getTable().getColumns()); columns.addAll(rightTable.getTable().getColumns()); DatabendNewExpressionGenerator joinGen = new DatabendNewExpressionGenerator(globalState) @@ -103,26 +101,30 @@ public static List> getJoins( return joinExpressions; } - public static DatabendJoin createRightOuterJoin(TableReferenceNode left, - TableReferenceNode right, Node predicate) { + public static DatabendJoin createRightOuterJoin(DatabendTableReference left, DatabendTableReference right, + DatabendExpression predicate) { return new DatabendJoin(left, right, JoinType.RIGHT, predicate); } - public static DatabendJoin createLeftOuterJoin(TableReferenceNode left, - TableReferenceNode right, Node predicate) { + public static DatabendJoin createLeftOuterJoin(DatabendTableReference left, DatabendTableReference right, + DatabendExpression predicate) { return new DatabendJoin(left, right, JoinType.LEFT, predicate); } - public static DatabendJoin createInnerJoin(TableReferenceNode left, - TableReferenceNode right, Node predicate) { + public static DatabendJoin createInnerJoin(DatabendTableReference left, DatabendTableReference right, + DatabendExpression predicate) { return new DatabendJoin(left, right, JoinType.INNER, predicate); } - public static Node createNaturalJoin(TableReferenceNode left, - TableReferenceNode right, OuterType naturalJoinType) { + public static DatabendJoin createNaturalJoin(DatabendTableReference left, DatabendTableReference right, + OuterType naturalJoinType) { DatabendJoin join = new DatabendJoin(left, right, JoinType.NATURAL, null); join.setOuterType(naturalJoinType); return join; } + @Override + public void setOnClause(DatabendExpression onClause) { + onCondition = onClause; + } } diff --git a/src/sqlancer/databend/ast/DatabendLikeOperation.java b/src/sqlancer/databend/ast/DatabendLikeOperation.java index 2755970e7..f40d804bb 100644 --- a/src/sqlancer/databend/ast/DatabendLikeOperation.java +++ b/src/sqlancer/databend/ast/DatabendLikeOperation.java @@ -1,17 +1,49 @@ package sqlancer.databend.ast; +import sqlancer.LikeImplementationHelper; import sqlancer.Randomly; import sqlancer.common.ast.BinaryOperatorNode; import sqlancer.common.ast.newast.NewBinaryOperatorNode; -import sqlancer.common.ast.newast.Node; +import sqlancer.databend.DatabendSchema.DatabendDataType; -public class DatabendLikeOperation extends NewBinaryOperatorNode { +public class DatabendLikeOperation extends NewBinaryOperatorNode implements DatabendExpression { - public DatabendLikeOperation(Node left, Node right, - DatabendLikeOperator op) { + public DatabendLikeOperation(DatabendExpression left, DatabendExpression right, DatabendLikeOperator op) { super(left, right, op); } + @Override + public DatabendDataType getExpectedType() { + return DatabendDataType.BOOLEAN; + } + + public DatabendExpression getLeftExpr() { + return super.getLeft(); + } + + public DatabendExpression getRightExpr() { + return super.getRight(); + } + + public DatabendLikeOperator getOp() { + return (DatabendLikeOperator) op; + } + + @Override + public DatabendConstant getExpectedValue() { + DatabendConstant leftVal = getLeftExpr().getExpectedValue(); + DatabendConstant rightVal = getRightExpr().getExpectedValue(); + if (leftVal == null || rightVal == null) { + return null; + } + if (leftVal.isNull() || rightVal.isNull()) { + return DatabendConstant.createNullConstant(); + } else { + boolean result = LikeImplementationHelper.match(leftVal.asString(), rightVal.asString(), 0, 0, true); + return DatabendConstant.createBooleanConstant(result); + } + } + public enum DatabendLikeOperator implements BinaryOperatorNode.Operator { LIKE_OPERATOR("LIKE", "like"); diff --git a/src/sqlancer/databend/ast/DatabendOrderByTerm.java b/src/sqlancer/databend/ast/DatabendOrderByTerm.java new file mode 100644 index 000000000..7492c7c33 --- /dev/null +++ b/src/sqlancer/databend/ast/DatabendOrderByTerm.java @@ -0,0 +1,9 @@ +package sqlancer.databend.ast; + +import sqlancer.common.ast.newast.NewOrderingTerm; + +public class DatabendOrderByTerm extends NewOrderingTerm implements DatabendExpression { + public DatabendOrderByTerm(DatabendExpression expr, Ordering ordering) { + super(expr, ordering); + } +} diff --git a/src/sqlancer/databend/ast/DatabendPostFixText.java b/src/sqlancer/databend/ast/DatabendPostFixText.java new file mode 100644 index 000000000..8f854fbe1 --- /dev/null +++ b/src/sqlancer/databend/ast/DatabendPostFixText.java @@ -0,0 +1,9 @@ +package sqlancer.databend.ast; + +import sqlancer.common.ast.newast.NewPostfixTextNode; + +public class DatabendPostFixText extends NewPostfixTextNode implements DatabendExpression { + public DatabendPostFixText(DatabendExpression expr, String text) { + super(expr, text); + } +} diff --git a/src/sqlancer/databend/ast/DatabendSelect.java b/src/sqlancer/databend/ast/DatabendSelect.java index b3ccc08f5..adc75d1cc 100644 --- a/src/sqlancer/databend/ast/DatabendSelect.java +++ b/src/sqlancer/databend/ast/DatabendSelect.java @@ -1,9 +1,16 @@ package sqlancer.databend.ast; +import java.util.List; +import java.util.stream.Collectors; + import sqlancer.common.ast.SelectBase; -import sqlancer.common.ast.newast.Node; +import sqlancer.common.ast.newast.Select; +import sqlancer.databend.DatabendSchema.DatabendColumn; +import sqlancer.databend.DatabendSchema.DatabendTable; +import sqlancer.databend.DatabendToStringVisitor; -public class DatabendSelect extends SelectBase> implements Node { +public class DatabendSelect extends SelectBase + implements DatabendExpression, Select { private boolean isDistinct; @@ -15,4 +22,20 @@ public boolean isDistinct() { return isDistinct; } + @Override + public void setJoinClauses(List joinStatements) { + List expressions = joinStatements.stream().map(e -> (DatabendExpression) e) + .collect(Collectors.toList()); + setJoinList(expressions); + } + + @Override + public List getJoinClauses() { + return getJoinList().stream().map(e -> (DatabendJoin) e).collect(Collectors.toList()); + } + + @Override + public String asString() { + return DatabendToStringVisitor.asString(this); + } } diff --git a/src/sqlancer/databend/ast/DatabendTableReference.java b/src/sqlancer/databend/ast/DatabendTableReference.java new file mode 100644 index 000000000..7a98d3877 --- /dev/null +++ b/src/sqlancer/databend/ast/DatabendTableReference.java @@ -0,0 +1,11 @@ +package sqlancer.databend.ast; + +import sqlancer.common.ast.newast.TableReferenceNode; +import sqlancer.databend.DatabendSchema; + +public class DatabendTableReference extends TableReferenceNode + implements DatabendExpression { + public DatabendTableReference(DatabendSchema.DatabendTable table) { + super(table); + } +} diff --git a/src/sqlancer/databend/ast/DatabendUnaryPostfixOperation.java b/src/sqlancer/databend/ast/DatabendUnaryPostfixOperation.java index e01c35282..119f93e5c 100644 --- a/src/sqlancer/databend/ast/DatabendUnaryPostfixOperation.java +++ b/src/sqlancer/databend/ast/DatabendUnaryPostfixOperation.java @@ -3,23 +3,35 @@ import sqlancer.Randomly; import sqlancer.common.ast.BinaryOperatorNode; import sqlancer.common.ast.newast.NewUnaryPostfixOperatorNode; -import sqlancer.common.ast.newast.Node; import sqlancer.databend.DatabendSchema.DatabendDataType; -public class DatabendUnaryPostfixOperation extends NewUnaryPostfixOperatorNode { +public class DatabendUnaryPostfixOperation extends NewUnaryPostfixOperatorNode + implements DatabendExpression { - // private final Node expr; - // private final DatabendUnaryPostfixOperator op; - private boolean negate; - - public DatabendUnaryPostfixOperation(Node expr, DatabendUnaryPostfixOperator op, - boolean negate) { + public DatabendUnaryPostfixOperation(DatabendExpression expr, DatabendUnaryPostfixOperator op) { super(expr, op); - setNegate(negate); } - public DatabendUnaryPostfixOperation(Node expr, DatabendUnaryPostfixOperator op) { - super(expr, op); + public DatabendExpression getExpression() { + return (DatabendExpression) getExpr(); + } + + public DatabendUnaryPostfixOperator getOp() { + return (DatabendUnaryPostfixOperator) op; + } + + @Override + public DatabendDataType getExpectedType() { + return DatabendDataType.BOOLEAN; + } + + @Override + public DatabendConstant getExpectedValue() { + DatabendConstant expectedValue = getExpression().getExpectedValue(); + if (expectedValue == null) { + return null; + } + return getOp().apply(expectedValue); } public enum DatabendUnaryPostfixOperator implements BinaryOperatorNode.Operator { @@ -28,12 +40,22 @@ public enum DatabendUnaryPostfixOperator implements BinaryOperatorNode.Operator public DatabendDataType[] getInputDataTypes() { return DatabendDataType.values(); } + + @Override + public DatabendConstant apply(DatabendConstant value) { + return DatabendConstant.createBooleanConstant(value.isNull()); + } }, IS_NOT_NULL("IS NOT NULL") { @Override public DatabendDataType[] getInputDataTypes() { return DatabendDataType.values(); } + + @Override + public DatabendConstant apply(DatabendConstant value) { + return DatabendConstant.createBooleanConstant(!value.isNull()); + } }; // IS @@ -54,19 +76,7 @@ public String getTextRepresentation() { public abstract DatabendDataType[] getInputDataTypes(); - } - - public boolean isNegated() { - return negate; - } - - public void setNegate(boolean negate) { - this.negate = negate; - } - - // @Override - public Node getExpression() { - return getExpr(); + public abstract DatabendConstant apply(DatabendConstant value); } @Override @@ -74,8 +84,4 @@ public String getOperatorRepresentation() { return this.op.getTextRepresentation(); } - // @Override - // public OperatorKind getOperatorKind() { - // return OperatorKind.POSTFIX; - // } } diff --git a/src/sqlancer/databend/ast/DatabendUnaryPrefixOperation.java b/src/sqlancer/databend/ast/DatabendUnaryPrefixOperation.java index 69c95a745..ea6b65f07 100644 --- a/src/sqlancer/databend/ast/DatabendUnaryPrefixOperation.java +++ b/src/sqlancer/databend/ast/DatabendUnaryPrefixOperation.java @@ -3,72 +3,87 @@ import sqlancer.Randomly; import sqlancer.common.ast.BinaryOperatorNode; import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode; -import sqlancer.common.ast.newast.Node; import sqlancer.databend.DatabendSchema.DatabendDataType; -public class DatabendUnaryPrefixOperation extends NewUnaryPrefixOperatorNode { +public class DatabendUnaryPrefixOperation extends NewUnaryPrefixOperatorNode + implements DatabendExpression { - // private final Node expr; - // private final DatabendUnaryPrefixOperator op; - // private boolean negate; - - public DatabendUnaryPrefixOperation(Node expr, DatabendUnaryPrefixOperator op) { + public DatabendUnaryPrefixOperation(DatabendExpression expr, DatabendUnaryPrefixOperator op) { super(expr, op); } - // public DatabendUnaryPrefixOperation(Node expr, DatabendUnaryPrefixOperator op, boolean - // negate) { - // super(expr,op); - // setNegate(negate); - // } + public DatabendExpression getExpression() { + return (DatabendExpression) getExpr(); + } - // void setNegate(boolean negate){ - // this.negate = negate; - // } + public DatabendUnaryPrefixOperator getOp() { + return (DatabendUnaryPrefixOperator) op; + } - // @Override - public Node getExpression() { - return getExpr(); + @Override + public DatabendDataType getExpectedType() { + return getOp().getExpressionType(getExpression()); } - // @Override - // public OperatorKind getOperatorKind() { - // return OperatorKind.PREFIX; - // } + @Override + public DatabendConstant getExpectedValue() { + DatabendConstant expectedValue = getExpression().getExpectedValue(); + if (expectedValue == null) { + return null; + } + return getOp().apply(expectedValue); + } public enum DatabendUnaryPrefixOperator implements BinaryOperatorNode.Operator { NOT("NOT", DatabendDataType.BOOLEAN, DatabendDataType.INT) { @Override - public DatabendDataType getExpressionType() { + public DatabendDataType getExpressionType(DatabendExpression expr) { return DatabendDataType.BOOLEAN; } @Override - protected DatabendConstant getExpectedValue(DatabendConstant expectedValue) { - return null; // TODO + protected DatabendConstant apply(DatabendConstant value) { + if (value.isNull()) { + return DatabendConstant.createNullConstant(); + } else { + return DatabendConstant.createBooleanConstant(!value.cast(DatabendDataType.BOOLEAN).asBoolean()); + } } }, UNARY_PLUS("+", DatabendDataType.INT) { @Override - public DatabendDataType getExpressionType() { - return DatabendDataType.INT; + public DatabendDataType getExpressionType(DatabendExpression expr) { + return expr.getExpectedType(); } @Override - protected DatabendConstant getExpectedValue(DatabendConstant expectedValue) { - return expectedValue; + protected DatabendConstant apply(DatabendConstant value) { + return value; } }, UNARY_MINUS("-", DatabendDataType.INT) { @Override - public DatabendDataType getExpressionType() { - return DatabendDataType.INT; + public DatabendDataType getExpressionType(DatabendExpression expr) { + return expr.getExpectedType(); } @Override - protected DatabendConstant getExpectedValue(DatabendConstant expectedValue) { - return null; + protected DatabendConstant apply(DatabendConstant value) { + if (value.isNull()) { + return DatabendConstant.createNullConstant(); + } + try { + if (value.isInt()) { + return DatabendConstant.createIntConstant(-value.asInt()); + } else if (value.isFloat()) { + return DatabendConstant.createFloatConstant(-value.asFloat()); + } else { + return null; + } + } catch (UnsupportedOperationException e) { + return null; + } } }; @@ -80,13 +95,13 @@ protected DatabendConstant getExpectedValue(DatabendConstant expectedValue) { this.dataTypes = dataTypes.clone(); } - public abstract DatabendDataType getExpressionType(); + public abstract DatabendDataType getExpressionType(DatabendExpression expr); public DatabendDataType getRandomInputDataTypes() { return Randomly.fromOptions(dataTypes); } - protected abstract DatabendConstant getExpectedValue(DatabendConstant expectedValue); + protected abstract DatabendConstant apply(DatabendConstant value); @Override public String getTextRepresentation() { diff --git a/src/sqlancer/databend/gen/DatabendDeleteGenerator.java b/src/sqlancer/databend/gen/DatabendDeleteGenerator.java new file mode 100644 index 000000000..256030409 --- /dev/null +++ b/src/sqlancer/databend/gen/DatabendDeleteGenerator.java @@ -0,0 +1,29 @@ +package sqlancer.databend.gen; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.databend.DatabendErrors; +import sqlancer.databend.DatabendProvider.DatabendGlobalState; +import sqlancer.databend.DatabendSchema.DatabendDataType; +import sqlancer.databend.DatabendToStringVisitor; + +public final class DatabendDeleteGenerator { + + private DatabendDeleteGenerator() { + } + + public static SQLQueryAdapter generate(DatabendGlobalState globalState) { + StringBuilder sb = new StringBuilder("DELETE FROM "); + ExpectedErrors errors = new ExpectedErrors(); + sb.append(globalState.getSchema().getRandomTable(t -> !t.isView()).getName()); + if (Randomly.getBoolean()) { + sb.append(" WHERE "); + sb.append(DatabendToStringVisitor.asString( + new DatabendNewExpressionGenerator(globalState).generateExpression(DatabendDataType.BOOLEAN))); + DatabendErrors.addExpressionErrors(errors); + } + return new SQLQueryAdapter(sb.toString(), errors); + } + +} diff --git a/src/sqlancer/databend/gen/DatabendNewExpressionGenerator.java b/src/sqlancer/databend/gen/DatabendNewExpressionGenerator.java index 43d38939f..8e52713b8 100644 --- a/src/sqlancer/databend/gen/DatabendNewExpressionGenerator.java +++ b/src/sqlancer/databend/gen/DatabendNewExpressionGenerator.java @@ -2,50 +2,121 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; +import java.util.Set; import java.util.stream.Collectors; import sqlancer.Randomly; -import sqlancer.common.ast.newast.NewBetweenOperatorNode; -import sqlancer.common.ast.newast.NewFunctionNode; -import sqlancer.common.ast.newast.NewInOperatorNode; -import sqlancer.common.ast.newast.Node; +import sqlancer.common.ast.newast.NewOrderingTerm; +import sqlancer.common.gen.NoRECGenerator; +import sqlancer.common.gen.TLPWhereGenerator; import sqlancer.common.gen.TypedExpressionGenerator; +import sqlancer.common.schema.AbstractTables; +import sqlancer.databend.DatabendBugs; import sqlancer.databend.DatabendProvider.DatabendGlobalState; import sqlancer.databend.DatabendSchema.DatabendColumn; +import sqlancer.databend.DatabendSchema.DatabendCompositeDataType; import sqlancer.databend.DatabendSchema.DatabendDataType; +import sqlancer.databend.DatabendSchema.DatabendRowValue; +import sqlancer.databend.DatabendSchema.DatabendTable; +import sqlancer.databend.DatabendToStringVisitor; +import sqlancer.databend.ast.DatabendAggregateOperation; +import sqlancer.databend.ast.DatabendAggregateOperation.DatabendAggregateFunction; +import sqlancer.databend.ast.DatabendBetweenOperation; import sqlancer.databend.ast.DatabendBinaryArithmeticOperation; import sqlancer.databend.ast.DatabendBinaryArithmeticOperation.DatabendBinaryArithmeticOperator; import sqlancer.databend.ast.DatabendBinaryComparisonOperation; import sqlancer.databend.ast.DatabendBinaryComparisonOperation.DatabendBinaryComparisonOperator; import sqlancer.databend.ast.DatabendBinaryLogicalOperation; import sqlancer.databend.ast.DatabendBinaryLogicalOperation.DatabendBinaryLogicalOperator; +import sqlancer.databend.ast.DatabendCastOperation; +import sqlancer.databend.ast.DatabendColumnReference; +import sqlancer.databend.ast.DatabendColumnValue; import sqlancer.databend.ast.DatabendConstant; import sqlancer.databend.ast.DatabendExpression; +import sqlancer.databend.ast.DatabendInOperation; +import sqlancer.databend.ast.DatabendJoin; import sqlancer.databend.ast.DatabendLikeOperation; +import sqlancer.databend.ast.DatabendOrderByTerm; +import sqlancer.databend.ast.DatabendPostFixText; +import sqlancer.databend.ast.DatabendSelect; +import sqlancer.databend.ast.DatabendTableReference; import sqlancer.databend.ast.DatabendUnaryPostfixOperation; import sqlancer.databend.ast.DatabendUnaryPostfixOperation.DatabendUnaryPostfixOperator; import sqlancer.databend.ast.DatabendUnaryPrefixOperation; import sqlancer.databend.ast.DatabendUnaryPrefixOperation.DatabendUnaryPrefixOperator; public class DatabendNewExpressionGenerator - extends TypedExpressionGenerator, DatabendColumn, DatabendDataType> { + extends TypedExpressionGenerator + implements NoRECGenerator, + TLPWhereGenerator { private final DatabendGlobalState globalState; + private List tables; + + private final int maxDepth; private boolean allowAggregateFunctions; + private DatabendRowValue rowValue; + + private Set columnOfLeafNode; + + public DatabendNewExpressionGenerator setRowValue(DatabendRowValue rowValue) { + this.rowValue = rowValue; + return this; + } + + public void setColumnOfLeafNode(Set columnOfLeafNode) { + this.columnOfLeafNode = columnOfLeafNode; + } public DatabendNewExpressionGenerator(DatabendGlobalState globalState) { this.globalState = globalState; + this.maxDepth = globalState.getOptions().getMaxExpressionDepth(); } @Override - public Node generateLeafNode(DatabendDataType dataType) { - return generateConstant(dataType); + public DatabendExpression generateLeafNode(DatabendDataType dataType) { + if (Randomly.getBoolean()) { + return generateConstant(dataType); + } else { + if (filterColumns(dataType).isEmpty()) { + return generateConstant(dataType); + } else { + return createColumnOfType(dataType); + } + } + } + + final List filterColumns(DatabendDataType dataType) { + if (columns == null) { + return Collections.emptyList(); + } else { + return columns.stream().filter(c -> c.getType().getPrimitiveDataType() == dataType) + .collect(Collectors.toList()); + } + } + + private DatabendExpression createColumnOfType(DatabendDataType type) { + List columns = filterColumns(type); + DatabendColumn column = Randomly.fromList(columns); + DatabendConstant value = rowValue == null ? null : rowValue.getValues().get(column); + if (columnOfLeafNode != null) { + columnOfLeafNode.add(DatabendColumnValue.create(column, value)); + } + return DatabendColumnValue.create(column, value); + } + + public List generateOrderBy() { + List randomColumns = Randomly.subset(columns); + return randomColumns.stream().map( + c -> new DatabendOrderByTerm(new DatabendColumnValue(c, null), NewOrderingTerm.Ordering.getRandom())) + .collect(Collectors.toList()); } @Override - protected Node generateExpression(DatabendDataType type, int depth) { - if (depth >= globalState.getOptions().getMaxExpressionDepth() || Randomly.getBoolean()) { + protected DatabendExpression generateExpression(DatabendDataType type, int depth) { + if (Randomly.getBooleanWithRatherLowProbability() || depth >= maxDepth) { return generateLeafNode(type); } @@ -56,6 +127,8 @@ protected Node generateExpression(DatabendDataType type, int return generateIntExpression(depth); case FLOAT: case VARCHAR: + case DATE: + case TIMESTAMP: case NULL: return generateConstant(type); default: @@ -63,8 +136,8 @@ protected Node generateExpression(DatabendDataType type, int } } - public List> generateExpressions(int nr, DatabendDataType type) { - List> expressions = new ArrayList<>(); + public List generateExpressions(int nr, DatabendDataType type) { + List expressions = new ArrayList<>(); for (int i = 0; i < nr; i++) { expressions.add(generateExpression(type)); } @@ -75,7 +148,7 @@ private enum IntExpression { UNARY_OPERATION, BINARY_ARITHMETIC_OPERATION } - private Node generateIntExpression(int depth) { + private DatabendExpression generateIntExpression(int depth) { if (allowAggregateFunctions) { allowAggregateFunctions = false; } @@ -99,11 +172,20 @@ private enum BooleanExpression { // SIMILAR_TO, POSIX_REGEX, BINARY_RANGE_COMPARISON,FUNCTION, CAST,; } - Node generateBooleanExpression(int depth) { + DatabendExpression generateBooleanExpression(int depth) { if (allowAggregateFunctions) { allowAggregateFunctions = false; } List validOptions = new ArrayList<>(Arrays.asList(BooleanExpression.values())); + if (DatabendBugs.bug15570) { + validOptions.remove(BooleanExpression.LIKE); + validOptions.remove(BooleanExpression.IN_OPERATION); + validOptions.remove(BooleanExpression.BETWEEN); + validOptions.remove(BooleanExpression.BINARY_COMPARISON); + } + if (DatabendBugs.bug15572) { + validOptions.remove(BooleanExpression.NOT); + } BooleanExpression option = Randomly.fromList(validOptions); switch (option) { case POSTFIX_OPERATOR: @@ -126,40 +208,39 @@ Node generateBooleanExpression(int depth) { } - Node getPostfix(int depth) { + DatabendExpression getPostfix(int depth) { DatabendUnaryPostfixOperator randomOp = DatabendUnaryPostfixOperator.getRandom(); return new DatabendUnaryPostfixOperation( - generateExpression(Randomly.fromOptions(randomOp.getInputDataTypes()), depth), randomOp, - Randomly.getBoolean()); + generateExpression(Randomly.fromOptions(randomOp.getInputDataTypes()), depth), randomOp); } - Node getNOT(int depth) { + DatabendExpression getNOT(int depth) { DatabendUnaryPrefixOperator op = DatabendUnaryPrefixOperator.NOT; return new DatabendUnaryPrefixOperation(generateExpression(op.getRandomInputDataTypes(), depth), op); } - Node getBetween(int depth) { + DatabendExpression getBetween(int depth) { // 跳过boolean DatabendDataType dataType = Randomly.fromList(Arrays.asList(DatabendDataType.values()).stream() .filter(t -> t != DatabendDataType.BOOLEAN).collect(Collectors.toList())); - return new NewBetweenOperatorNode(generateExpression(dataType, depth), - generateExpression(dataType, depth), generateExpression(dataType, depth), Randomly.getBoolean()); + return new DatabendBetweenOperation(generateExpression(dataType, depth), generateExpression(dataType, depth), + generateExpression(dataType, depth), Randomly.getBoolean()); } - Node getIn(int depth) { + DatabendExpression getIn(int depth) { DatabendDataType dataType = Randomly.fromOptions(DatabendDataType.values()); - Node leftExpr = generateExpression(dataType, depth); - List> rightExprs = new ArrayList<>(); + DatabendExpression leftExpr = generateExpression(dataType, depth); + List rightExprs = new ArrayList<>(); int nr = Randomly.smallNumber() + 1; for (int i = 0; i < nr; i++) { rightExprs.add(generateExpression(dataType, depth)); } - return new NewInOperatorNode(leftExpr, rightExprs, Randomly.getBoolean()); + return new DatabendInOperation(leftExpr, rightExprs, Randomly.getBoolean()); } - Node getBinaryLogical(int depth, DatabendDataType dataType) { - Node expr = generateExpression(dataType, depth); + DatabendExpression getBinaryLogical(int depth, DatabendDataType dataType) { + DatabendExpression expr = generateExpression(dataType, depth); int nr = Randomly.smallNumber() + 1; for (int i = 0; i < nr; i++) { expr = new DatabendBinaryLogicalOperation(expr, generateExpression(DatabendDataType.BOOLEAN, depth), @@ -168,37 +249,48 @@ Node getBinaryLogical(int depth, DatabendDataType dataType) return expr; } - Node getComparison(int depth) { + DatabendExpression getComparison(int depth) { // 跳过boolean DatabendDataType dataType = Randomly.fromList(Arrays.asList(DatabendDataType.values()).stream() .filter(t -> t != DatabendDataType.BOOLEAN).collect(Collectors.toList())); - Node leftExpr = generateExpression(dataType, depth); - Node rightExpr = generateExpression(dataType, depth); + DatabendExpression leftExpr = generateExpression(dataType, depth); + DatabendExpression rightExpr = generateExpression(dataType, depth); return new DatabendBinaryComparisonOperation(leftExpr, rightExpr, Randomly.fromOptions(DatabendBinaryComparisonOperator.values())); } - Node getLike(int depth, DatabendDataType dataType) { + DatabendExpression getLike(int depth, DatabendDataType dataType) { return new DatabendLikeOperation(generateExpression(dataType, depth), generateExpression(dataType, depth), DatabendLikeOperation.DatabendLikeOperator.LIKE_OPERATOR); } + public DatabendExpression generateExpressionWithExpectedResult(DatabendDataType type) { + // DatabendNewExpressionGenerator gen = new + // DatabendNewExpressionGenerator(globalState).setColumns(columns); + // gen.setRowValue(rowValue); + DatabendExpression expr; + do { + expr = this.generateExpression(type); + } while (expr.getExpectedValue() == null); + return expr; + } + @Override - public Node generatePredicate() { + public DatabendExpression generatePredicate() { return generateExpression(DatabendDataType.BOOLEAN); } @Override - public Node negatePredicate(Node predicate) { + public DatabendExpression negatePredicate(DatabendExpression predicate) { return new DatabendUnaryPrefixOperation(predicate, DatabendUnaryPrefixOperator.NOT); } @Override - public Node isNull(Node predicate) { + public DatabendExpression isNull(DatabendExpression predicate) { return new DatabendUnaryPostfixOperation(predicate, DatabendUnaryPostfixOperator.IS_NULL); } - public Node generateConstant(DatabendDataType type, boolean isNullable) { + public DatabendExpression generateConstant(DatabendDataType type, boolean isNullable) { if (isNullable && Randomly.getBooleanWithSmallProbability()) { createConstant(DatabendDataType.NULL); } @@ -206,14 +298,14 @@ public Node generateConstant(DatabendDataType type, boolean } @Override - public Node generateConstant(DatabendDataType type) { + public DatabendExpression generateConstant(DatabendDataType type) { if (Randomly.getBooleanWithSmallProbability()) { return DatabendConstant.createNullConstant(); } return createConstant(type); } - public Node createConstant(DatabendDataType type) { + public DatabendExpression createConstant(DatabendDataType type) { Randomly r = globalState.getRandomly(); switch (type) { case INT: @@ -228,13 +320,17 @@ public Node createConstant(DatabendDataType type) { return DatabendConstant.createStringConstant(r.getString()); case NULL: return DatabendConstant.createNullConstant(); + case DATE: + return DatabendConstant.createDateConstant(r.getInteger()); + case TIMESTAMP: + return DatabendConstant.createTimestampConstant(r.getInteger()); default: throw new AssertionError(type); } } @Override - protected Node generateColumn(DatabendDataType type) { + protected DatabendExpression generateColumn(DatabendDataType type) { return null; } @@ -248,54 +344,95 @@ protected boolean canGenerateColumnOfType(DatabendDataType type) { return false; } - public enum DatabendAggregateFunction { - MAX(1), MIN(1), AVG(1, DatabendDataType.INT, DatabendDataType.FLOAT), COUNT(1), - SUM(1, DatabendDataType.INT, DatabendDataType.FLOAT), STDDEV_POP(1), COVAR_POP(1), COVAR_SAMP(2); - // , STRING_AGG(1), STDDEV_SAMP(1),VAR_SAMP(1), VAR_POP(1) + public DatabendExpression generateArgsForAggregate(DatabendAggregateFunction aggregateFunction) { + return new DatabendAggregateOperation( + generateExpressions(aggregateFunction.getNrArgs(), aggregateFunction.getRandomType()), + aggregateFunction); + } - private int nrArgs; - private DatabendDataType[] dataTypes; + public DatabendExpression generateAggregate() { + DatabendAggregateFunction aggrFunc = DatabendAggregateFunction.getRandom(); + return generateArgsForAggregate(aggrFunc); + } - DatabendAggregateFunction(int nrArgs, DatabendDataType... dataTypes) { - this.nrArgs = nrArgs; - this.dataTypes = dataTypes.clone(); - } + public DatabendExpression generateHavingClause() { + allowAggregateFunctions = true; + DatabendExpression expression = generateExpression(DatabendDataType.BOOLEAN); + allowAggregateFunctions = false; + return expression; + } - public static DatabendAggregateFunction getRandom() { - return Randomly.fromOptions(values()); - } + @Override + public DatabendNewExpressionGenerator setTablesAndColumns(AbstractTables tables) { + this.columns = tables.getColumns(); + this.tables = tables.getTables(); - public DatabendDataType getRandomType() { - if (dataTypes.length == 0) { - return Randomly.fromOptions(DatabendDataType.values()); - } else { - return Randomly.fromOptions(dataTypes); - } - } + return this; + } - public int getNrArgs() { - return nrArgs; - } + @Override + public DatabendExpression generateBooleanExpression() { + return generateExpression(DatabendDataType.BOOLEAN); + } + @Override + public DatabendSelect generateSelect() { + return new DatabendSelect(); } - public NewFunctionNode generateArgsForAggregate( - DatabendAggregateFunction aggregateFunction) { - return new NewFunctionNode( - generateExpressions(aggregateFunction.getNrArgs(), aggregateFunction.getRandomType()), - aggregateFunction); + @Override + public List getRandomJoinClauses() { + List tableList = tables.stream().map(t -> new DatabendTableReference(t)) + .collect(Collectors.toList()); + List joins = DatabendJoin.getJoins(tableList, globalState); + tables = tableList.stream().map(t -> t.getTable()).collect(Collectors.toList()); + return joins; } - public Node generateAggregate() { - DatabendAggregateFunction aggrFunc = DatabendAggregateFunction.getRandom(); - return generateArgsForAggregate(aggrFunc); + @Override + public List getTableRefs() { + return tables.stream().map(t -> new DatabendTableReference(t)).collect(Collectors.toList()); } - public Node generateHavingClause() { - allowAggregateFunctions = true; - Node expression = generateExpression(DatabendDataType.BOOLEAN); - allowAggregateFunctions = false; - return expression; + @Override + public String generateOptimizedQueryString(DatabendSelect select, DatabendExpression whereCondition, + boolean shouldUseAggregate) { + if (shouldUseAggregate) { + DatabendExpression aggr = new DatabendAggregateOperation( + List.of(new DatabendColumnReference(new DatabendColumn("*", + new DatabendCompositeDataType(DatabendDataType.INT, 0), false, false))), + DatabendAggregateFunction.COUNT); + select.setFetchColumns(List.of(aggr)); + } else { + List allColumns = columns.stream().map((c) -> new DatabendColumnReference(c)) + .collect(Collectors.toList()); + select.setFetchColumns(allColumns); + if (Randomly.getBooleanWithSmallProbability()) { + select.setOrderByClauses(generateOrderBys()); + } + } + select.setWhereClause(whereCondition); + + return select.asString(); + } + + @Override + public String generateUnoptimizedQueryString(DatabendSelect select, DatabendExpression whereCondition) { + DatabendExpression asText = new DatabendPostFixText(new DatabendCastOperation( + new DatabendPostFixText(whereCondition, + " IS NOT NULL AND " + DatabendToStringVisitor.asString(whereCondition)), + new DatabendCompositeDataType(DatabendDataType.INT, 8)), "as count"); + select.setFetchColumns(List.of(asText)); + select.setWhereClause(null); + + return "SELECT SUM(count) FROM (" + select.asString() + ") as res"; } + @Override + public List generateFetchColumns(boolean shouldCreateDummy) { + if (shouldCreateDummy) { + return List.of(new DatabendColumnReference(new DatabendColumn("*", null, false, false))); + } + return columns.stream().map(c -> new DatabendColumnReference(c)).collect(Collectors.toList()); + } } diff --git a/src/sqlancer/databend/gen/DatabendRandomQuerySynthesizer.java b/src/sqlancer/databend/gen/DatabendRandomQuerySynthesizer.java index 286764c6e..875c5afee 100644 --- a/src/sqlancer/databend/gen/DatabendRandomQuerySynthesizer.java +++ b/src/sqlancer/databend/gen/DatabendRandomQuerySynthesizer.java @@ -1,20 +1,22 @@ package sqlancer.databend.gen; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.stream.Collectors; import sqlancer.Randomly; -import sqlancer.common.ast.newast.Node; -import sqlancer.common.ast.newast.TableReferenceNode; import sqlancer.databend.DatabendProvider.DatabendGlobalState; import sqlancer.databend.DatabendSchema; +import sqlancer.databend.DatabendSchema.DatabendColumn; import sqlancer.databend.DatabendSchema.DatabendTable; import sqlancer.databend.DatabendSchema.DatabendTables; +import sqlancer.databend.ast.DatabendColumnValue; import sqlancer.databend.ast.DatabendConstant; import sqlancer.databend.ast.DatabendExpression; import sqlancer.databend.ast.DatabendJoin; import sqlancer.databend.ast.DatabendSelect; +import sqlancer.databend.ast.DatabendTableReference; public final class DatabendRandomQuerySynthesizer { @@ -22,39 +24,58 @@ private DatabendRandomQuerySynthesizer() { } public static DatabendSelect generateSelect(DatabendGlobalState globalState, int nrColumns) { - DatabendTables targetTables = globalState.getSchema().getRandomTableNonEmptyTables(); - DatabendNewExpressionGenerator gen = new DatabendNewExpressionGenerator(globalState) - .setColumns(targetTables.getColumns()); - DatabendSelect select = new DatabendSelect(); - // TODO distinct - // select.setDistinct(Randomly.getBoolean()); + DatabendTables targetTables = globalState.getSchema().getRandomTableNonEmptyAndViewTables(); + List targetColumns = targetTables.getColumns(); + DatabendNewExpressionGenerator gen = new DatabendNewExpressionGenerator(globalState).setColumns(targetColumns); // boolean allowAggregates = Randomly.getBooleanWithSmallProbability(); - List> columns = new ArrayList<>(); + List columns = new ArrayList<>(); + HashSet columnOfLeafNode = new HashSet<>(); + gen.setColumnOfLeafNode(columnOfLeafNode); + int freeColumns = targetColumns.size(); for (int i = 0; i < nrColumns; i++) { // if (allowAggregates && Randomly.getBoolean()) { - Node expression = gen.generateExpression(DatabendSchema.DatabendDataType.BOOLEAN); - columns.add(expression); - // } else { - // columns.add(gen()); - // } + DatabendExpression column = null; + if (freeColumns > 0 && Randomly.getBoolean()) { + column = new DatabendColumnValue(targetColumns.get(freeColumns - 1), null); + freeColumns -= 1; + columnOfLeafNode.add((DatabendColumnValue) column); + } else { + column = gen.generateExpression(DatabendSchema.DatabendDataType.BOOLEAN); + } + columns.add(column); } + DatabendSelect select = new DatabendSelect(); + boolean isDistinct = Randomly.getBoolean(); + select.setDistinct(isDistinct); select.setFetchColumns(columns); List tables = targetTables.getTables(); - List> tableList = tables.stream() - .map(t -> new TableReferenceNode(t)).collect(Collectors.toList()); - List> joins = DatabendJoin.getJoins(tableList, globalState); + List tableList = tables.stream().map(t -> new DatabendTableReference(t)) + .collect(Collectors.toList()); + List joins = DatabendJoin.getJoins(tableList, globalState); select.setJoinList(joins.stream().collect(Collectors.toList())); select.setFromList(tableList.stream().collect(Collectors.toList())); if (Randomly.getBoolean()) { select.setWhereClause(gen.generateExpression(DatabendSchema.DatabendDataType.BOOLEAN)); } - // if (Randomly.getBoolean()) {//TODO order by超过实际行数 - // select.setOrderByExpressions(gen.generateOrderBys()); - // } - // if (Randomly.getBoolean()) { //TODO group by超过实际行数 - // select.setGroupByExpressions(gen.generateExpressions(Randomly.smallNumber() + 1)); - // } + List noExprColumns = new ArrayList<>(columnOfLeafNode); + + if (Randomly.getBoolean() && !noExprColumns.isEmpty() && !isDistinct) { + select.setOrderByClauses(Randomly.nonEmptySubset(noExprColumns)); + // TODO (for SELECT DISTINCT, ORDER BY expressions must appear in select list) + // isDistinct + // 需要orderby输入每个select list,可以用数字代替比如:1,2,3... + } + + if (Randomly.getBoolean()) { // 可能产生新的column叶子结点 + select.setHavingClause(gen.generateHavingClause()); + } + + noExprColumns = new ArrayList<>(columnOfLeafNode); + + if (Randomly.getBoolean() && !noExprColumns.isEmpty()) { + select.setGroupByExpressions(noExprColumns); + } if (Randomly.getBoolean()) { select.setLimitClause( @@ -64,10 +85,7 @@ public static DatabendSelect generateSelect(DatabendGlobalState globalState, int select.setOffsetClause( DatabendConstant.createIntConstant(Randomly.getNotCachedInteger(0, Integer.MAX_VALUE))); } - // TODO 待添加HavingClause - // if (Randomly.getBoolean()) { - // select.setHavingClause(gen.generateHavingClause()); - // } + return select; } diff --git a/src/sqlancer/databend/gen/DatabendTableGenerator.java b/src/sqlancer/databend/gen/DatabendTableGenerator.java index c83cd93e9..514d740d0 100644 --- a/src/sqlancer/databend/gen/DatabendTableGenerator.java +++ b/src/sqlancer/databend/gen/DatabendTableGenerator.java @@ -4,10 +4,10 @@ import java.util.List; import sqlancer.Randomly; -import sqlancer.common.ast.newast.Node; import sqlancer.common.gen.TypedExpressionGenerator; import sqlancer.common.query.ExpectedErrors; import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.databend.DatabendErrors; import sqlancer.databend.DatabendProvider.DatabendGlobalState; import sqlancer.databend.DatabendSchema.DatabendColumn; import sqlancer.databend.DatabendSchema.DatabendCompositeDataType; @@ -19,13 +19,14 @@ public class DatabendTableGenerator { public SQLQueryAdapter getQuery(DatabendGlobalState globalState) { ExpectedErrors errors = new ExpectedErrors(); + DatabendErrors.addExpressionErrors(errors); StringBuilder sb = new StringBuilder(); String tableName = globalState.getSchema().getFreeTableName(); sb.append("CREATE TABLE "); sb.append(tableName); sb.append("("); List columns = getNewColumns(); - TypedExpressionGenerator, DatabendColumn, DatabendDataType> gen = new DatabendNewExpressionGenerator( + TypedExpressionGenerator gen = new DatabendNewExpressionGenerator( globalState).setColumns(columns); for (int i = 0; i < columns.size(); i++) { if (i != 0) { @@ -34,27 +35,14 @@ public SQLQueryAdapter getQuery(DatabendGlobalState globalState) { sb.append(columns.get(i).getName()); sb.append(" "); sb.append(columns.get(i).getType()); - // if (globalState.getDbmsSpecificOptions().testCollate && Randomly.getBooleanWithRatherLowProbability() - // && columns.get(i).getType().getPrimitiveDataType() == DatabendDataType.VARCHAR) { - // sb.append(" COLLATE "); - // sb.append(getRandomCollate()); - // } - // if (globalState.getDbmsSpecificOptions().testIndexes && Randomly.getBooleanWithRatherLowProbability()) { - // sb.append(" UNIQUE"); - // } + if (globalState.getDbmsSpecificOptions().testNotNullConstraints && Randomly.getBooleanWithRatherLowProbability()) { sb.append(" NOT NULL"); } else { sb.append(" NULL"); // Databend 默认字段为非空,这个将它默认设置为允许空 } - // if (globalState.getDbmsSpecificOptions().testCheckConstraints //databend 无check约束 - // && Randomly.getBooleanWithRatherLowProbability()) { - // sb.append(" CHECK("); - // sb.append(DatabendToStringVisitor.asString(gen.generateExpression())); - // DatabendErrors.addExpressionErrors(errors); - // sb.append(")"); - // } + if (Randomly.getBoolean() && globalState.getDbmsSpecificOptions().testDefaultValues) { sb.append(" DEFAULT("); sb.append(DatabendToStringVisitor.asString(// 常量类型于字段类型等同 @@ -62,22 +50,11 @@ public SQLQueryAdapter getQuery(DatabendGlobalState globalState) { sb.append(")"); } } - // databend并没有索引 - // if (globalState.getDbmsSpecificOptions().testIndexes && Randomly.getBoolean()) { - // errors.add("Invalid type for index"); - // List primaryKeyColumns = Randomly.nonEmptySubset(columns); - // sb.append(", PRIMARY KEY("); - // sb.append(primaryKeyColumns.stream().map(c -> c.getName()).collect(Collectors.joining(", "))); - // sb.append(")"); - // } + sb.append(")"); return new SQLQueryAdapter(sb.toString(), errors, true); } - public static String getRandomCollate() { - return Randomly.fromOptions("NOCASE", "NOACCENT", "NOACCENT.NOCASE", "C", "POSIX"); - } - private static List getNewColumns() { List columns = new ArrayList<>(); for (int i = 0; i < Randomly.smallNumber() + 1; i++) { diff --git a/src/sqlancer/databend/gen/DatabendViewGenerator.java b/src/sqlancer/databend/gen/DatabendViewGenerator.java index 55cc6ffc5..75ad57239 100644 --- a/src/sqlancer/databend/gen/DatabendViewGenerator.java +++ b/src/sqlancer/databend/gen/DatabendViewGenerator.java @@ -17,15 +17,6 @@ public static SQLQueryAdapter generate(DatabendGlobalState globalState) { StringBuilder sb = new StringBuilder("CREATE "); sb.append("VIEW "); sb.append(globalState.getSchema().getFreeViewName()); - // sb.append("("); - // for (int i = 0; i < nrColumns; i++) { - // if (i != 0) { - // sb.append(", "); - // } - // sb.append("c"); - // sb.append(i); - // } - // sb.append(") AS "); sb.append(" AS "); sb.append(DatabendToStringVisitor .asString(DatabendRandomQuerySynthesizer.generateSelect(globalState, nrColumns))); diff --git a/src/sqlancer/databend/test/DatabendNoRECOracle.java b/src/sqlancer/databend/test/DatabendNoRECOracle.java deleted file mode 100644 index cb27fa78b..000000000 --- a/src/sqlancer/databend/test/DatabendNoRECOracle.java +++ /dev/null @@ -1,140 +0,0 @@ -package sqlancer.databend.test; - -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Statement; -import java.util.Arrays; -import java.util.List; -import java.util.stream.Collectors; - -import sqlancer.IgnoreMeException; -import sqlancer.Randomly; -import sqlancer.SQLConnection; -import sqlancer.common.ast.newast.ColumnReferenceNode; -import sqlancer.common.ast.newast.NewPostfixTextNode; -import sqlancer.common.ast.newast.Node; -import sqlancer.common.ast.newast.TableReferenceNode; -import sqlancer.common.oracle.NoRECBase; -import sqlancer.common.oracle.TestOracle; -import sqlancer.common.query.SQLQueryAdapter; -import sqlancer.common.query.SQLancerResultSet; -import sqlancer.databend.DatabendErrors; -import sqlancer.databend.DatabendProvider.DatabendGlobalState; -import sqlancer.databend.DatabendSchema; -import sqlancer.databend.DatabendSchema.DatabendColumn; -import sqlancer.databend.DatabendSchema.DatabendCompositeDataType; -import sqlancer.databend.DatabendSchema.DatabendDataType; -import sqlancer.databend.DatabendSchema.DatabendTable; -import sqlancer.databend.DatabendSchema.DatabendTables; -import sqlancer.databend.DatabendToStringVisitor; -import sqlancer.databend.ast.DatabendCastOperation; -import sqlancer.databend.ast.DatabendExpression; -import sqlancer.databend.ast.DatabendJoin; -import sqlancer.databend.ast.DatabendSelect; -import sqlancer.databend.gen.DatabendNewExpressionGenerator; - -public class DatabendNoRECOracle extends NoRECBase implements TestOracle { - - private final DatabendSchema s; - - public DatabendNoRECOracle(DatabendGlobalState globalState) { - super(globalState); - this.s = globalState.getSchema(); - DatabendErrors.addExpressionErrors(errors); - } - - @Override - public void check() throws SQLException { - DatabendTables randomTables = s.getRandomTableNonEmptyTables(); // 随机获得nr张表 - List columns = randomTables.getColumns(); - DatabendNewExpressionGenerator gen = new DatabendNewExpressionGenerator(state).setColumns(columns); - - Node randomWhereCondition = gen.generateExpression(DatabendDataType.BOOLEAN); // 生成随机where条件,形式为ast - - List tables = randomTables.getTables(); - List> tableList = tables.stream() - .map(t -> new TableReferenceNode(t)).collect(Collectors.toList()); - List> joins = DatabendJoin.getJoins(tableList, state); - int secondCount = getSecondQuery(tableList.stream().collect(Collectors.toList()), randomWhereCondition, joins); // 禁用优化 - int firstCount = getFirstQueryCount(con, tableList.stream().collect(Collectors.toList()), columns, - randomWhereCondition, joins); - if (firstCount == -1 || secondCount == -1) { - throw new IgnoreMeException(); - } - if (firstCount != secondCount) { - throw new AssertionError( - optimizedQueryString + "; -- " + firstCount + "\n" + unoptimizedQueryString + " -- " + secondCount); - } - } - - private int getSecondQuery(List> tableList, Node randomWhereCondition, - List> joins) throws SQLException { - DatabendSelect select = new DatabendSelect(); - // select.setGroupByClause(groupBys); - // DatabendExpression isTrue = DatabendPostfixOperation.create(randomWhereCondition, - // PostfixOperator.IS_TRUE); - Node asText = new NewPostfixTextNode<>(new DatabendCastOperation( - new NewPostfixTextNode(randomWhereCondition, - " IS NOT NULL AND " + DatabendToStringVisitor.asString(randomWhereCondition)), - new DatabendCompositeDataType(DatabendDataType.INT, 8)), "as count"); - - select.setFetchColumns(Arrays.asList(asText)); // ? - select.setFromList(tableList); - select.setJoinList(joins); - int secondCount = 0; - unoptimizedQueryString = "SELECT SUM(count) FROM (" + DatabendToStringVisitor.asString(select) + ") as res"; - errors.add("canceling statement due to statement timeout"); - SQLQueryAdapter q = new SQLQueryAdapter(unoptimizedQueryString, errors); - SQLancerResultSet rs; - try { - rs = q.executeAndGetLogged(state); - } catch (Exception e) { - throw new AssertionError(unoptimizedQueryString, e); - } - if (rs == null) { - return -1; - } - if (rs.next()) { - secondCount += rs.getLong(1); - } - rs.close(); - return secondCount; - } - - private int getFirstQueryCount(SQLConnection con, List> tableList, - List columns, Node randomWhereCondition, - List> joins) throws SQLException { - DatabendSelect select = new DatabendSelect(); - // select.setGroupByClause(groupBys); - // DatabendAggregate aggr = new DatabendAggregate( - List> allColumns = columns.stream() - .map((c) -> new ColumnReferenceNode(c)) - .collect(Collectors.toList()); - // DatabendAggregateFunction.COUNT); - // select.setFetchColumns(Arrays.asList(aggr)); - select.setFetchColumns(allColumns); - select.setFromList(tableList); - select.setWhereClause(randomWhereCondition); - if (Randomly.getBooleanWithSmallProbability()) { - select.setOrderByExpressions( - new DatabendNewExpressionGenerator(state).setColumns(columns).generateOrderBys()); - } - select.setJoinList(joins); - int firstCount = 0; - try (Statement stat = con.createStatement()) { - optimizedQueryString = DatabendToStringVisitor.asString(select); - if (options.logEachSelect()) { - logger.writeCurrent(optimizedQueryString); - } - try (ResultSet rs = stat.executeQuery(optimizedQueryString)) { - while (rs.next()) { - firstCount++; - } - } - } catch (SQLException e) { - throw new IgnoreMeException(); - } - return firstCount; - } - -} diff --git a/src/sqlancer/databend/test/DatabendPivotedQuerySynthesisOracle.java b/src/sqlancer/databend/test/DatabendPivotedQuerySynthesisOracle.java new file mode 100644 index 000000000..4a72454b5 --- /dev/null +++ b/src/sqlancer/databend/test/DatabendPivotedQuerySynthesisOracle.java @@ -0,0 +1,150 @@ +package sqlancer.databend.test; + +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.common.oracle.PivotedQuerySynthesisBase; +import sqlancer.common.query.Query; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.databend.DatabendErrors; +import sqlancer.databend.DatabendExpectedValueVisitor; +import sqlancer.databend.DatabendProvider.DatabendGlobalState; +import sqlancer.databend.DatabendSchema.DatabendColumn; +import sqlancer.databend.DatabendSchema.DatabendDataType; +import sqlancer.databend.DatabendSchema.DatabendRowValue; +import sqlancer.databend.DatabendSchema.DatabendTables; +import sqlancer.databend.DatabendToStringVisitor; +import sqlancer.databend.ast.DatabendColumnValue; +import sqlancer.databend.ast.DatabendConstant; +import sqlancer.databend.ast.DatabendExpression; +import sqlancer.databend.ast.DatabendSelect; +import sqlancer.databend.ast.DatabendTableReference; +import sqlancer.databend.ast.DatabendUnaryPostfixOperation; +import sqlancer.databend.ast.DatabendUnaryPrefixOperation; +import sqlancer.databend.gen.DatabendNewExpressionGenerator; + +public class DatabendPivotedQuerySynthesisOracle + extends PivotedQuerySynthesisBase { + + private List fetchColumns; + + public DatabendPivotedQuerySynthesisOracle(DatabendGlobalState globalState) { + super(globalState); + DatabendErrors.addExpressionErrors(errors); + DatabendErrors.addInsertErrors(errors); + } + + @Override + protected Query getRectifiedQuery() throws Exception { + DatabendTables randomTables = globalState.getSchema().getRandomTableNonEmptyAndViewTables(); + List columns = randomTables.getColumns(); + DatabendSelect selectStatement = new DatabendSelect(); + boolean isDistinct = Randomly.getBoolean(); + selectStatement.setDistinct(isDistinct); + pivotRow = randomTables.getRandomRowValue(globalState.getConnection()); + fetchColumns = columns; + selectStatement.setFetchColumns(fetchColumns.stream() + .map(c -> new DatabendColumnValue(getFetchValueAliasedColumn(c), pivotRow.getValues().get(c))) + .collect(Collectors.toList())); + selectStatement.setFromList( + randomTables.getTables().stream().map(t -> new DatabendTableReference(t)).collect(Collectors.toList())); + DatabendExpression whereClause = generateRectifiedExpression(columns, pivotRow); + selectStatement.setWhereClause(whereClause); + List groupByClause = generateGroupByClause(columns, pivotRow); + selectStatement.setGroupByExpressions(groupByClause); + DatabendExpression limitClause = generateLimit(); + selectStatement.setLimitClause(limitClause); + if (limitClause != null) { + DatabendExpression offsetClause = generateOffset(); + selectStatement.setOffsetClause(offsetClause); + } + DatabendNewExpressionGenerator gen = new DatabendNewExpressionGenerator(globalState).setColumns(columns); + if (!isDistinct) { + List orderBys = gen.generateOrderBy(); + selectStatement.setOrderByClauses(orderBys); + } + return new SQLQueryAdapter(DatabendToStringVisitor.asString(selectStatement), errors); + } + + private DatabendExpression generateRectifiedExpression(List columns, DatabendRowValue pivotRow) { + DatabendNewExpressionGenerator gen = new DatabendNewExpressionGenerator(globalState).setColumns(columns); + gen.setRowValue(pivotRow); + DatabendExpression expr = gen.generateExpressionWithExpectedResult(DatabendDataType.BOOLEAN); + DatabendExpression result = null; + if (expr.getExpectedValue().isNull()) { + result = new DatabendUnaryPostfixOperation(expr, + DatabendUnaryPostfixOperation.DatabendUnaryPostfixOperator.IS_NULL); + } else if (!expr.getExpectedValue().cast(DatabendDataType.BOOLEAN).asBoolean()) { + result = new DatabendUnaryPrefixOperation(expr, + DatabendUnaryPrefixOperation.DatabendUnaryPrefixOperator.NOT); + } + rectifiedPredicates.add(result); + return result; + } + + @Override + protected Query getContainmentCheckQuery(Query pivotRowQuery) throws Exception { + StringBuilder sb = new StringBuilder(); + sb.append("SELECT * FROM ("); + sb.append(pivotRowQuery.getUnterminatedQueryString()); + sb.append(") as result WHERE "); + int i = 0; + for (DatabendColumn c : fetchColumns) { + if (i++ != 0) { + sb.append(" AND "); + } + sb.append("result."); + sb.append(c.getTable().getName()); + sb.append(c.getName()); + if (pivotRow.getValues().get(c).isNull()) { + sb.append(" IS NULL "); + } else { + sb.append(" = "); + sb.append(pivotRow.getValues().get(c).toString()); + } + } + String resultingQueryString = sb.toString(); + return new SQLQueryAdapter(resultingQueryString, errors); + } + + private DatabendColumn getFetchValueAliasedColumn(DatabendColumn c) { + DatabendColumn aliasedColumn = new DatabendColumn(c.getName() + " AS " + c.getTable().getName() + c.getName(), + c.getType(), false, false); + aliasedColumn.setTable(c.getTable()); + return aliasedColumn; + } + + @Override + protected String getExpectedValues(DatabendExpression expr) { + return DatabendExpectedValueVisitor.asExpectedValues(expr); + } + + private List generateGroupByClause(List columns, DatabendRowValue rowValue) { + if (Randomly.getBoolean()) { + return columns.stream().map(c -> new DatabendColumnValue(c, rowValue.getValues().get(c))) + .collect(Collectors.toList()); + } else { + return Collections.emptyList(); + } + } + + private DatabendExpression generateLimit() { + if (Randomly.getBoolean()) { + return DatabendConstant.createIntConstant(Integer.MAX_VALUE); + } else { + return null; + } + } + + private DatabendExpression generateOffset() { + if (Randomly.getBoolean()) { + return DatabendConstant.createIntConstant(0); + } else { + return null; + } + } + +} diff --git a/src/sqlancer/databend/test/DatabendQueryPartitioningBase.java b/src/sqlancer/databend/test/DatabendQueryPartitioningBase.java deleted file mode 100644 index 6ef132572..000000000 --- a/src/sqlancer/databend/test/DatabendQueryPartitioningBase.java +++ /dev/null @@ -1,88 +0,0 @@ -package sqlancer.databend.test; - -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; -import java.util.stream.Collectors; - -import sqlancer.Randomly; -import sqlancer.common.ast.newast.ColumnReferenceNode; -import sqlancer.common.ast.newast.Node; -import sqlancer.common.ast.newast.TableReferenceNode; -import sqlancer.common.gen.ExpressionGenerator; -import sqlancer.common.oracle.TernaryLogicPartitioningOracleBase; -import sqlancer.common.oracle.TestOracle; -import sqlancer.databend.DatabendErrors; -import sqlancer.databend.DatabendProvider.DatabendGlobalState; -import sqlancer.databend.DatabendSchema; -import sqlancer.databend.DatabendSchema.DatabendColumn; -import sqlancer.databend.DatabendSchema.DatabendTable; -import sqlancer.databend.DatabendSchema.DatabendTables; -import sqlancer.databend.ast.DatabendExpression; -import sqlancer.databend.ast.DatabendJoin; -import sqlancer.databend.ast.DatabendSelect; -import sqlancer.databend.gen.DatabendNewExpressionGenerator; - -public class DatabendQueryPartitioningBase extends - TernaryLogicPartitioningOracleBase, DatabendGlobalState> implements TestOracle { - - DatabendSchema s; - DatabendTables targetTables; - DatabendNewExpressionGenerator gen; - DatabendSelect select; - - public DatabendQueryPartitioningBase(DatabendGlobalState state) { - super(state); - DatabendErrors.addExpressionErrors(errors); - } - - public static String canonicalizeResultValue(String value) { - // Rule: -0.0 should be canonicalized to 0.0 - if (Objects.equals(value, "-0.0")) { - return "0.0"; - } - - return value; - } - - @Override - public void check() throws SQLException { - s = state.getSchema(); - targetTables = s.getRandomTableNonEmptyTables(); - gen = new DatabendNewExpressionGenerator(state).setColumns(targetTables.getColumns()); - initializeTernaryPredicateVariants(); - select = new DatabendSelect(); - select.setFetchColumns(generateRandomColumns()); - List tables = targetTables.getTables(); - List> tableList = tables.stream() - .map(t -> new TableReferenceNode(t)).collect(Collectors.toList()); - List> joins = DatabendJoin.getJoins(tableList, state); - select.setJoinList(joins.stream().collect(Collectors.toList())); - select.setFromList(tableList.stream().collect(Collectors.toList())); - select.setWhereClause(null); - } - - List> generateFetchColumns() { - List> columns = new ArrayList<>(); - if (Randomly.getBoolean()) { // TODO 为什么会返回 false 或 true 字段 - columns.add(new ColumnReferenceNode<>(new DatabendColumn("*", null, false, false))); - } else { - columns = generateRandomColumns(); - } - return columns; - } - - List> generateRandomColumns() { - List> columns; - columns = Randomly.nonEmptySubset(targetTables.getColumns()).stream() - .map(c -> new ColumnReferenceNode(c)).collect(Collectors.toList()); - return columns; - } - - @Override - protected ExpressionGenerator> getGen() { - return gen; - } - -} diff --git a/src/sqlancer/databend/test/DatabendQueryPartitioningWhereTester.java b/src/sqlancer/databend/test/DatabendQueryPartitioningWhereTester.java deleted file mode 100644 index 2b43d0ba0..000000000 --- a/src/sqlancer/databend/test/DatabendQueryPartitioningWhereTester.java +++ /dev/null @@ -1,45 +0,0 @@ -package sqlancer.databend.test; - -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.List; - -import sqlancer.ComparatorHelper; -import sqlancer.databend.DatabendErrors; -import sqlancer.databend.DatabendProvider.DatabendGlobalState; -import sqlancer.databend.DatabendToStringVisitor; - -public class DatabendQueryPartitioningWhereTester extends DatabendQueryPartitioningBase { - - public DatabendQueryPartitioningWhereTester(DatabendGlobalState state) { - super(state); - DatabendErrors.addGroupByErrors(errors); - } - - @Override - public void check() throws SQLException { - super.check(); - select.setWhereClause(null); - String originalQueryString = DatabendToStringVisitor.asString(select); - - List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); - - // boolean orderBy = Randomly.getBooleanWithRatherLowProbability(); - boolean orderBy = false; - // if (orderBy) { //TODO 待开启 - // select.setOrderByExpressions(gen.generateOrderBys()); - // } - select.setWhereClause(predicate); - String firstQueryString = DatabendToStringVisitor.asString(select); - select.setWhereClause(negatedPredicate); - String secondQueryString = DatabendToStringVisitor.asString(select); - select.setWhereClause(isNullPredicate); - String thirdQueryString = DatabendToStringVisitor.asString(select); - List combinedString = new ArrayList<>(); - List secondResultSet = ComparatorHelper.getCombinedResultSet(firstQueryString, secondQueryString, - thirdQueryString, combinedString, !orderBy, state, errors); - ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, - state, DatabendQueryPartitioningBase::canonicalizeResultValue); - } - -} diff --git a/src/sqlancer/databend/test/DatabendQueryPartitioningAggregateTester.java b/src/sqlancer/databend/test/tlp/DatabendQueryPartitioningAggregateTester.java similarity index 63% rename from src/sqlancer/databend/test/DatabendQueryPartitioningAggregateTester.java rename to src/sqlancer/databend/test/tlp/DatabendQueryPartitioningAggregateTester.java index 59d6695ca..6d52caea3 100644 --- a/src/sqlancer/databend/test/DatabendQueryPartitioningAggregateTester.java +++ b/src/sqlancer/databend/test/tlp/DatabendQueryPartitioningAggregateTester.java @@ -1,4 +1,4 @@ -package sqlancer.databend.test; +package sqlancer.databend.test.tlp; import java.sql.SQLException; import java.util.ArrayList; @@ -8,29 +8,29 @@ import sqlancer.ComparatorHelper; import sqlancer.IgnoreMeException; import sqlancer.Randomly; -import sqlancer.common.ast.newast.NewAliasNode; -import sqlancer.common.ast.newast.NewBinaryOperatorNode; -import sqlancer.common.ast.newast.NewFunctionNode; -import sqlancer.common.ast.newast.NewUnaryPostfixOperatorNode; -import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode; -import sqlancer.common.ast.newast.Node; -import sqlancer.common.oracle.TestOracle; import sqlancer.common.query.SQLQueryAdapter; import sqlancer.common.query.SQLancerResultSet; +import sqlancer.databend.DatabendBugs; import sqlancer.databend.DatabendErrors; import sqlancer.databend.DatabendProvider.DatabendGlobalState; import sqlancer.databend.DatabendSchema.DatabendCompositeDataType; import sqlancer.databend.DatabendSchema.DatabendDataType; import sqlancer.databend.DatabendToStringVisitor; +import sqlancer.databend.ast.DatabendAggregateOperation; +import sqlancer.databend.ast.DatabendAggregateOperation.DatabendAggregateFunction; +import sqlancer.databend.ast.DatabendAlias; import sqlancer.databend.ast.DatabendBinaryArithmeticOperation.DatabendBinaryArithmeticOperator; +import sqlancer.databend.ast.DatabendBinaryOperation; import sqlancer.databend.ast.DatabendCastOperation; import sqlancer.databend.ast.DatabendExpression; +import sqlancer.databend.ast.DatabendFunctionOperation; import sqlancer.databend.ast.DatabendSelect; +import sqlancer.databend.ast.DatabendUnaryPostfixOperation; import sqlancer.databend.ast.DatabendUnaryPostfixOperation.DatabendUnaryPostfixOperator; +import sqlancer.databend.ast.DatabendUnaryPrefixOperation; import sqlancer.databend.ast.DatabendUnaryPrefixOperation.DatabendUnaryPrefixOperator; -import sqlancer.databend.gen.DatabendNewExpressionGenerator.DatabendAggregateFunction; -public class DatabendQueryPartitioningAggregateTester extends DatabendQueryPartitioningBase implements TestOracle { +public class DatabendQueryPartitioningAggregateTester extends DatabendQueryPartitioningBase { private String firstResult; private String secondResult; @@ -45,19 +45,23 @@ public DatabendQueryPartitioningAggregateTester(DatabendGlobalState state) { @Override public void check() throws SQLException { super.check(); - DatabendAggregateFunction aggregateFunction = Randomly.fromOptions(DatabendAggregateFunction.MAX, - DatabendAggregateFunction.MIN, DatabendAggregateFunction.SUM, DatabendAggregateFunction.COUNT, - DatabendAggregateFunction.AVG/* , DatabendAggregateFunction.STDDEV_POP */); - NewFunctionNode aggregate = gen + List aggregateFunctions = new ArrayList<>(List.of(DatabendAggregateFunction.MAX, + DatabendAggregateFunction.MIN, DatabendAggregateFunction.SUM, DatabendAggregateFunction.COUNT + /* , DatabendAggregateFunction.STDDEV_POP */)); + if (!DatabendBugs.bug19738) { + aggregateFunctions.add(DatabendAggregateFunction.AVG); + } + DatabendAggregateFunction aggregateFunction = Randomly.fromList(aggregateFunctions); + DatabendFunctionOperation aggregate = (DatabendAggregateOperation) gen .generateArgsForAggregate(aggregateFunction); - List> fetchColumns = new ArrayList<>(); + List fetchColumns = new ArrayList<>(); fetchColumns.add(aggregate); while (Randomly.getBooleanWithRatherLowProbability()) { - fetchColumns.add(gen.generateAggregate()); + fetchColumns.add((DatabendAggregateOperation) gen.generateAggregate()); // TODO 更换成非聚合函数 } select.setFetchColumns(Arrays.asList(aggregate)); // if (Randomly.getBooleanWithRatherLowProbability()) { - // select.setOrderByExpressions(gen.generateOrderBys()); + // select.setOrderByClauses(gen.generateOrderBys()); // } originalQuery = DatabendToStringVisitor.asString(select); firstResult = getAggregateResult(originalQuery); @@ -78,15 +82,14 @@ public void check() throws SQLException { } private String createMetamorphicUnionQuery(DatabendSelect select, - NewFunctionNode aggregate, - List> from) { + DatabendFunctionOperation aggregate, List from) { String metamorphicQuery; - Node whereClause = gen.generateExpression(DatabendDataType.BOOLEAN); - Node negatedClause = new NewUnaryPrefixOperatorNode<>(whereClause, + DatabendExpression whereClause = gen.generateExpression(DatabendDataType.BOOLEAN); + DatabendExpression negatedClause = new DatabendUnaryPrefixOperation(whereClause, DatabendUnaryPrefixOperator.NOT); - Node notNullClause = new NewUnaryPostfixOperatorNode<>(whereClause, + DatabendExpression notNullClause = new DatabendUnaryPostfixOperation(whereClause, DatabendUnaryPostfixOperator.IS_NULL); - List> mappedAggregate = mapped(aggregate); + List mappedAggregate = mapped(aggregate); DatabendSelect leftSelect = getSelect(mappedAggregate, from, whereClause, select.getJoinList()); DatabendSelect middleSelect = getSelect(mappedAggregate, from, negatedClause, select.getJoinList()); DatabendSelect rightSelect = getSelect(mappedAggregate, from, notNullClause, select.getJoinList()); @@ -124,8 +127,7 @@ private String getAggregateResult(String queryString) throws SQLException { } } - private List> mapped( - NewFunctionNode aggregate) { + private List mapped(DatabendFunctionOperation aggregate) { DatabendCastOperation count; switch (aggregate.getFunc()) { case COUNT: @@ -134,22 +136,21 @@ private List> mapped( case SUM: return aliasArgs(Arrays.asList(aggregate)); case AVG: - NewFunctionNode sum = new NewFunctionNode<>( + DatabendFunctionOperation sum = new DatabendFunctionOperation<>( aggregate.getArgs(), DatabendAggregateFunction.SUM); count = new DatabendCastOperation( - new NewFunctionNode<>(aggregate.getArgs(), DatabendAggregateFunction.COUNT), + new DatabendFunctionOperation<>(aggregate.getArgs(), DatabendAggregateFunction.COUNT), new DatabendCompositeDataType(DatabendDataType.FLOAT, 8)); return aliasArgs(Arrays.asList(sum, count)); case STDDEV_POP: - NewFunctionNode sumSquared = new NewFunctionNode<>( - Arrays.asList(new NewBinaryOperatorNode<>(aggregate.getArgs().get(0), aggregate.getArgs().get(0), + DatabendFunctionOperation sumSquared = new DatabendFunctionOperation<>( + Arrays.asList(new DatabendBinaryOperation(aggregate.getArgs().get(0), aggregate.getArgs().get(0), DatabendBinaryArithmeticOperator.MULTIPLICATION)), DatabendAggregateFunction.SUM); count = new DatabendCastOperation( - new NewFunctionNode(aggregate.getArgs(), - DatabendAggregateFunction.COUNT), + new DatabendFunctionOperation<>(aggregate.getArgs(), DatabendAggregateFunction.COUNT), new DatabendCompositeDataType(DatabendDataType.FLOAT, 8)); - NewFunctionNode avg = new NewFunctionNode<>( + DatabendFunctionOperation avg = new DatabendFunctionOperation<>( aggregate.getArgs(), DatabendAggregateFunction.AVG); return aliasArgs(Arrays.asList(sumSquared, count, avg)); default: @@ -157,16 +158,16 @@ private List> mapped( } } - private List> aliasArgs(List> originalAggregateArgs) { - List> args = new ArrayList<>(); + private List aliasArgs(List originalAggregateArgs) { + List args = new ArrayList<>(); int i = 0; - for (Node expr : originalAggregateArgs) { - args.add(new NewAliasNode(expr, "agg" + i++)); + for (DatabendExpression expr : originalAggregateArgs) { + args.add(new DatabendAlias(expr, "agg" + i++)); } return args; } - private String getOuterAggregateFunction(NewFunctionNode aggregate) { + private String getOuterAggregateFunction(DatabendFunctionOperation aggregate) { switch (aggregate.getFunc()) { case STDDEV_POP: return "sqrt(SUM(agg0)/SUM(agg1)-SUM(agg2)*SUM(agg2))"; @@ -179,18 +180,18 @@ private String getOuterAggregateFunction(NewFunctionNode> aggregates, List> from, - Node whereClause, List> joinList) { - DatabendSelect leftSelect = new DatabendSelect(); - leftSelect.setFetchColumns(aggregates); - leftSelect.setFromList(from); - leftSelect.setWhereClause(whereClause); - leftSelect.setJoinList(joinList); - // if (Randomly.getBooleanWithSmallProbability()) { - // leftSelect.setGroupByExpressions(gen.generateExpressions(Randomly.smallNumber() + 1)); //TODO group by超过实际行数 - // leftSelect.setGroupByExpressions(select.getFetchColumns());// TODO group by不能放入聚合函数 - // } - return leftSelect; + private DatabendSelect getSelect(List aggregates, List from, + DatabendExpression whereClause, List joinList) { + DatabendSelect select = new DatabendSelect(); + select.setFetchColumns(aggregates); + select.setFromList(from); + select.setWhereClause(whereClause); + select.setJoinList(joinList); + if (Randomly.getBooleanWithSmallProbability()) { + select.setGroupByExpressions(List.of(gen.generateConstant(DatabendDataType.INT))); // TODO + // 仍可加强 + } + return select; } } diff --git a/src/sqlancer/databend/test/tlp/DatabendQueryPartitioningBase.java b/src/sqlancer/databend/test/tlp/DatabendQueryPartitioningBase.java new file mode 100644 index 000000000..611164902 --- /dev/null +++ b/src/sqlancer/databend/test/tlp/DatabendQueryPartitioningBase.java @@ -0,0 +1,94 @@ +package sqlancer.databend.test.tlp; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.common.gen.ExpressionGenerator; +import sqlancer.common.oracle.TernaryLogicPartitioningOracleBase; +import sqlancer.common.oracle.TestOracle; +import sqlancer.databend.DatabendBugs; +import sqlancer.databend.DatabendErrors; +import sqlancer.databend.DatabendProvider.DatabendGlobalState; +import sqlancer.databend.DatabendSchema; +import sqlancer.databend.DatabendSchema.DatabendColumn; +import sqlancer.databend.DatabendSchema.DatabendTable; +import sqlancer.databend.DatabendSchema.DatabendTables; +import sqlancer.databend.ast.DatabendColumnReference; +import sqlancer.databend.ast.DatabendColumnValue; +import sqlancer.databend.ast.DatabendExpression; +import sqlancer.databend.ast.DatabendJoin; +import sqlancer.databend.ast.DatabendSelect; +import sqlancer.databend.ast.DatabendTableReference; +import sqlancer.databend.gen.DatabendNewExpressionGenerator; + +public class DatabendQueryPartitioningBase + extends TernaryLogicPartitioningOracleBase + implements TestOracle { + + DatabendSchema s; + DatabendTables targetTables; + DatabendNewExpressionGenerator gen; + DatabendSelect select; + + List groupByExpression; + + public DatabendQueryPartitioningBase(DatabendGlobalState state) { + super(state); + DatabendErrors.addExpressionErrors(errors); + } + + @Override + public void check() throws SQLException { + s = state.getSchema(); + targetTables = s.getRandomTableNonEmptyAndViewTables(); + List randomColumn = targetTables.getColumns(); + + gen = new DatabendNewExpressionGenerator(state).setColumns(targetTables.getColumns()); + HashSet columnOfLeafNode = new HashSet<>(); + gen.setColumnOfLeafNode(columnOfLeafNode); + initializeTernaryPredicateVariants(); + select = new DatabendSelect(); + columnOfLeafNode + .addAll(randomColumn.stream().map(c -> new DatabendColumnValue(c, null)).collect(Collectors.toList())); + groupByExpression = new ArrayList<>(columnOfLeafNode); + + select.setFetchColumns( + randomColumn.stream().map(c -> new DatabendColumnReference(c)).collect(Collectors.toList())); + List tables = targetTables.getTables(); + List tableList = tables.stream().map(t -> new DatabendTableReference(t)) + .collect(Collectors.toList()); + if (!DatabendBugs.bug9236) { + List joins = DatabendJoin.getJoins(tableList, state); + select.setJoinList(joins.stream().collect(Collectors.toList())); + } + select.setFromList(tableList.stream().collect(Collectors.toList())); + select.setWhereClause(null); + } + + List generateFetchColumns() { + List columns = new ArrayList<>(); + if (Randomly.getBoolean()) { + columns.add(new DatabendColumnReference(new DatabendColumn("*", null, false, false))); + } else { + columns = generateRandomColumns(); + } + return columns; + } + + List generateRandomColumns() { + List columns; + columns = Randomly.nonEmptySubset(targetTables.getColumns()).stream().map(c -> new DatabendColumnReference(c)) + .collect(Collectors.toList()); + return columns; + } + + @Override + protected ExpressionGenerator getGen() { + return gen; + } + +} diff --git a/src/sqlancer/databend/test/DatabendQueryPartitioningDistinctTester.java b/src/sqlancer/databend/test/tlp/DatabendQueryPartitioningDistinctTester.java similarity index 86% rename from src/sqlancer/databend/test/DatabendQueryPartitioningDistinctTester.java rename to src/sqlancer/databend/test/tlp/DatabendQueryPartitioningDistinctTester.java index 8eae1e502..ce20ef214 100644 --- a/src/sqlancer/databend/test/DatabendQueryPartitioningDistinctTester.java +++ b/src/sqlancer/databend/test/tlp/DatabendQueryPartitioningDistinctTester.java @@ -1,4 +1,4 @@ -package sqlancer.databend.test; +package sqlancer.databend.test.tlp; import java.sql.SQLException; import java.util.ArrayList; @@ -21,6 +21,8 @@ public DatabendQueryPartitioningDistinctTester(DatabendGlobalState state) { public void check() throws SQLException { super.check(); select.setDistinct(true); + // TODO 后期可以使用and来进行扩展 + // select.setWhereClause(DatabendExprToNode.cast(gen.generateExpression(DatabendSchema.DatabendDataType.BOOLEAN))); select.setWhereClause(null); String originalQueryString = DatabendToStringVisitor.asString(select); @@ -38,7 +40,7 @@ public void check() throws SQLException { List secondResultSet = ComparatorHelper.getCombinedResultSetNoDuplicates(firstQueryString, secondQueryString, thirdQueryString, combinedString, true, state, errors); ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, - state, DatabendQueryPartitioningBase::canonicalizeResultValue); + state, ComparatorHelper::canonicalizeResultValue); } } diff --git a/src/sqlancer/databend/test/DatabendQueryPartitioningGroupByTester.java b/src/sqlancer/databend/test/tlp/DatabendQueryPartitioningGroupByTester.java similarity index 77% rename from src/sqlancer/databend/test/DatabendQueryPartitioningGroupByTester.java rename to src/sqlancer/databend/test/tlp/DatabendQueryPartitioningGroupByTester.java index dd25b6d7f..9d45291b3 100644 --- a/src/sqlancer/databend/test/DatabendQueryPartitioningGroupByTester.java +++ b/src/sqlancer/databend/test/tlp/DatabendQueryPartitioningGroupByTester.java @@ -1,4 +1,4 @@ -package sqlancer.databend.test; +package sqlancer.databend.test.tlp; import java.sql.SQLException; import java.util.ArrayList; @@ -7,12 +7,10 @@ import sqlancer.ComparatorHelper; import sqlancer.Randomly; -import sqlancer.common.ast.newast.ColumnReferenceNode; -import sqlancer.common.ast.newast.Node; import sqlancer.databend.DatabendErrors; import sqlancer.databend.DatabendProvider.DatabendGlobalState; -import sqlancer.databend.DatabendSchema.DatabendColumn; import sqlancer.databend.DatabendToStringVisitor; +import sqlancer.databend.ast.DatabendColumnReference; import sqlancer.databend.ast.DatabendExpression; public class DatabendQueryPartitioningGroupByTester extends DatabendQueryPartitioningBase { @@ -25,7 +23,7 @@ public DatabendQueryPartitioningGroupByTester(DatabendGlobalState state) { @Override public void check() throws SQLException { super.check(); - select.setGroupByExpressions(select.getFetchColumns()); + select.setGroupByExpressions(groupByExpression); select.setWhereClause(null); String originalQueryString = DatabendToStringVisitor.asString(select); @@ -41,13 +39,13 @@ public void check() throws SQLException { List secondResultSet = ComparatorHelper.getCombinedResultSetNoDuplicates(firstQueryString, secondQueryString, thirdQueryString, combinedString, true, state, errors); ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, - state, DatabendQueryPartitioningBase::canonicalizeResultValue); + state, ComparatorHelper::canonicalizeResultValue); } @Override - List> generateFetchColumns() { - return Randomly.nonEmptySubset(targetTables.getColumns()).stream() - .map(c -> new ColumnReferenceNode(c)).collect(Collectors.toList()); + List generateFetchColumns() { + return Randomly.nonEmptySubset(targetTables.getColumns()).stream().map(c -> new DatabendColumnReference(c)) + .collect(Collectors.toList()); } } diff --git a/src/sqlancer/databend/test/DatabendQueryPartitioningHavingTester.java b/src/sqlancer/databend/test/tlp/DatabendQueryPartitioningHavingTester.java similarity index 76% rename from src/sqlancer/databend/test/DatabendQueryPartitioningHavingTester.java rename to src/sqlancer/databend/test/tlp/DatabendQueryPartitioningHavingTester.java index 8ca1489f2..33de8c896 100644 --- a/src/sqlancer/databend/test/DatabendQueryPartitioningHavingTester.java +++ b/src/sqlancer/databend/test/tlp/DatabendQueryPartitioningHavingTester.java @@ -1,4 +1,4 @@ -package sqlancer.databend.test; +package sqlancer.databend.test.tlp; import java.sql.SQLException; import java.util.ArrayList; @@ -7,8 +7,6 @@ import sqlancer.ComparatorHelper; import sqlancer.Randomly; -import sqlancer.common.ast.newast.Node; -import sqlancer.common.oracle.TestOracle; import sqlancer.databend.DatabendErrors; import sqlancer.databend.DatabendProvider.DatabendGlobalState; import sqlancer.databend.DatabendSchema; @@ -16,7 +14,7 @@ import sqlancer.databend.ast.DatabendConstant; import sqlancer.databend.ast.DatabendExpression; -public class DatabendQueryPartitioningHavingTester extends DatabendQueryPartitioningBase implements TestOracle { +public class DatabendQueryPartitioningHavingTester extends DatabendQueryPartitioningBase { public DatabendQueryPartitioningHavingTester(DatabendGlobalState state) { super(state); @@ -32,17 +30,15 @@ public void check() throws SQLException { // boolean orderBy = Randomly.getBoolean(); boolean orderBy = false; // 关闭order by if (orderBy) { // TODO 生成columns.size()的子集,有个错误:order by 后不能直接union,需要包装一层select - // select.setOrderByExpressions(gen.generateOrderBys()); - List> constants = new ArrayList<>(); + // select.setOrderByClauses(gen.generateOrderBys()); + List constants = new ArrayList<>(); constants.add(new DatabendConstant.DatabendIntConstant( Randomly.smallNumber() % select.getFetchColumns().size() + 1)); - select.setOrderByExpressions(constants); + select.setOrderByClauses(constants); } - // select.setGroupByExpressions(gen.generateExpressions(Randomly.smallNumber() + 1)); - select.setGroupByExpressions(select.getFetchColumns()); + select.setGroupByExpressions(groupByExpression); select.setHavingClause(null); String originalQueryString = DatabendToStringVisitor.asString(select); - // System.out.println(originalQueryString); List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); select.setHavingClause(predicate); @@ -55,16 +51,16 @@ public void check() throws SQLException { List secondResultSet = ComparatorHelper.getCombinedResultSet(firstQueryString, secondQueryString, thirdQueryString, combinedString, !orderBy, state, errors); ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, - state, DatabendQueryPartitioningBase::canonicalizeResultValue); + state, ComparatorHelper::canonicalizeResultValue); } @Override - protected Node generatePredicate() { + protected DatabendExpression generatePredicate() { return gen.generateHavingClause(); } @Override - List> generateFetchColumns() { + List generateFetchColumns() { return Collections.singletonList(gen.generateHavingClause()); } diff --git a/src/sqlancer/datafusion/DataFusionErrors.java b/src/sqlancer/datafusion/DataFusionErrors.java new file mode 100644 index 000000000..4be35017e --- /dev/null +++ b/src/sqlancer/datafusion/DataFusionErrors.java @@ -0,0 +1,54 @@ +package sqlancer.datafusion; + +import static sqlancer.datafusion.DataFusionUtil.dfAssert; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.common.query.ExpectedErrors; + +public final class DataFusionErrors { + private DataFusionErrors() { + dfAssert(false, "Utility class cannot be instantiated"); + } + + /* + * During Oracle Checks, if ANY query returns one of the following error Then the current oracle check will be + * skipped. e.g.: NoREC Q1 -> throw an expected error NoREC Q2 -> succeed Since it's a known error, `SQLancer` will + * skip this check and don't report bug. + * + * Note now it's implemented this way for simplicity This way might cause false negative, because Q1 and Q2 should + * both succeed or both fail TODO(datafusion): ensure both succeed or both fail + */ + public static List getExpectedExecutionErrors() { + ArrayList errors = new ArrayList<>(); + /* + * Expected + */ + errors.add("Error building plan"); // Randomly generated SQL is not valid and caused palning error + errors.add("Error during planning"); + errors.add("Execution error"); + errors.add("Overflow happened"); + errors.add("overflow"); + errors.add("Unsupported data type"); + errors.add("Divide by zero"); + /* + * Known bugs + */ + errors.add("to type Int64"); // https://github.com/apache/datafusion/issues/11252 + errors.add("bitwise"); // https://github.com/apache/datafusion/issues/11260 + errors.add("NestedLoopJoinExec"); // https://github.com/apache/datafusion/issues/11269 + /* + * False positives + */ + errors.add("Physical plan does not support logical expression AggregateFunction"); // False positive: when aggr + // is generated in where + // clause + + return errors; + } + + public static void registerExpectedExecutionErrors(ExpectedErrors errors) { + errors.addAll(getExpectedExecutionErrors()); + } +} diff --git a/src/sqlancer/datafusion/DataFusionOptions.java b/src/sqlancer/datafusion/DataFusionOptions.java new file mode 100644 index 000000000..582b2f658 --- /dev/null +++ b/src/sqlancer/datafusion/DataFusionOptions.java @@ -0,0 +1,21 @@ +package sqlancer.datafusion; + +import java.util.Arrays; +import java.util.List; + +import com.beust.jcommander.Parameter; +import com.beust.jcommander.Parameters; + +import sqlancer.DBMSSpecificOptions; + +@Parameters(commandDescription = "DataFusion") +public class DataFusionOptions implements DBMSSpecificOptions { + @Parameter(names = "--debug-info", description = "Show debug messages related to DataFusion", arity = 0) + public boolean showDebugInfo; + + @Override + public List getTestOracleFactory() { + return Arrays.asList(DataFusionOracleFactory.NOREC, DataFusionOracleFactory.QUERY_PARTITIONING_WHERE); + } + +} diff --git a/src/sqlancer/datafusion/DataFusionOracleFactory.java b/src/sqlancer/datafusion/DataFusionOracleFactory.java new file mode 100644 index 000000000..a7a3f21e8 --- /dev/null +++ b/src/sqlancer/datafusion/DataFusionOracleFactory.java @@ -0,0 +1,34 @@ +package sqlancer.datafusion; + +import java.sql.SQLException; + +import sqlancer.OracleFactory; +import sqlancer.common.oracle.NoRECOracle; +import sqlancer.common.oracle.TLPWhereOracle; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.datafusion.gen.DataFusionExpressionGenerator; + +public enum DataFusionOracleFactory implements OracleFactory { + NOREC { + @Override + public TestOracle create( + DataFusionProvider.DataFusionGlobalState globalState) throws SQLException { + DataFusionExpressionGenerator gen = new DataFusionExpressionGenerator(globalState); + ExpectedErrors errors = ExpectedErrors.newErrors().with(DataFusionErrors.getExpectedExecutionErrors()) + .with("canceling statement due to statement timeout").build(); + return new NoRECOracle<>(globalState, gen, errors); + } + }, + QUERY_PARTITIONING_WHERE { + @Override + public TestOracle create( + DataFusionProvider.DataFusionGlobalState globalState) throws SQLException { + DataFusionExpressionGenerator gen = new DataFusionExpressionGenerator(globalState); + ExpectedErrors expectedErrors = ExpectedErrors.newErrors() + .with(DataFusionErrors.getExpectedExecutionErrors()).build(); + + return new TLPWhereOracle<>(globalState, gen, expectedErrors); + } + } +} diff --git a/src/sqlancer/datafusion/DataFusionProvider.java b/src/sqlancer/datafusion/DataFusionProvider.java new file mode 100644 index 000000000..37328e4ba --- /dev/null +++ b/src/sqlancer/datafusion/DataFusionProvider.java @@ -0,0 +1,134 @@ +package sqlancer.datafusion; + +import static sqlancer.datafusion.DataFusionUtil.DataFusionLogger.DataFusionLogType.DML; +import static sqlancer.datafusion.DataFusionUtil.dfAssert; +import static sqlancer.datafusion.DataFusionUtil.displayTables; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.List; +import java.util.Properties; +import java.util.stream.Collectors; + +import com.google.auto.service.AutoService; + +import sqlancer.DatabaseProvider; +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.SQLGlobalState; +import sqlancer.SQLProviderAdapter; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.datafusion.DataFusionProvider.DataFusionGlobalState; +import sqlancer.datafusion.DataFusionSchema.DataFusionTable; +import sqlancer.datafusion.DataFusionUtil.DataFusionInstanceID; +import sqlancer.datafusion.DataFusionUtil.DataFusionLogger; +import sqlancer.datafusion.gen.DataFusionInsertGenerator; +import sqlancer.datafusion.gen.DataFusionTableGenerator; + +@AutoService(DatabaseProvider.class) +public class DataFusionProvider extends SQLProviderAdapter { + + public DataFusionProvider() { + super(DataFusionGlobalState.class, DataFusionOptions.class); + } + + @Override + public void generateDatabase(DataFusionGlobalState globalState) throws Exception { + int tableCount = Randomly.fromOptions(1, 2, 3, 4, 5, 6, 7); + for (int i = 0; i < tableCount; i++) { + SQLQueryAdapter queryCreateRandomTable = new DataFusionTableGenerator().getQuery(globalState); + queryCreateRandomTable.execute(globalState); + globalState.updateSchema(); + globalState.dfLogger.appendToLog(DML, queryCreateRandomTable.toString() + "\n"); + } + + // Now only `INSERT` DML is supported + // If more DMLs are added later, should use`StatementExecutor` instead + // (see DuckDB's implementation for reference) + + globalState.updateSchema(); + List allTables = globalState.getSchema().getDatabaseTables(); + List allTablesName = allTables.stream().map(t -> t.getName()).collect(Collectors.toList()); + if (allTablesName.isEmpty()) { + dfAssert(false, "Generate Database failed."); + } + + // Randomly insert some data into existing tables + for (DataFusionTable table : allTables) { + int nInsertQuery = globalState.getRandomly().getInteger(0, globalState.getOptions().getMaxNumberInserts()); + + for (int i = 0; i < nInsertQuery; i++) { + SQLQueryAdapter insertQuery = null; + try { + insertQuery = DataFusionInsertGenerator.getQuery(globalState, table); + } catch (IgnoreMeException e) { + // Only for special case: table has 0 column + continue; + } + + insertQuery.execute(globalState); + globalState.dfLogger.appendToLog(DML, insertQuery.toString() + "\n"); + } + } + + // TODO(datafusion) add `DataFUsionLogType.STATE` for this whole db state log + if (globalState.getDbmsSpecificOptions().showDebugInfo) { + System.out.println(displayTables(globalState, allTablesName)); + } + } + + @Override + public SQLConnection createDatabase(DataFusionGlobalState globalState) throws SQLException { + if (globalState.getDbmsSpecificOptions().showDebugInfo) { + System.out.println("A new database get created!\n"); + } + Properties props = new Properties(); + props.setProperty("UseEncryption", "false"); + // must set 'user' and 'password' to trigger server 'do_handshake()' + props.setProperty("user", "foo"); + props.setProperty("password", "bar"); + props.setProperty("create", globalState.getDatabaseName()); // Hack: use this property to let DataFusion server + // clear the current context + String url = "jdbc:arrow-flight-sql://127.0.0.1:50051"; + Connection connection = DriverManager.getConnection(url, props); + + return new SQLConnection(connection); + } + + @Override + public String getDBMSName() { + return "datafusion"; + } + + // If run SQLancer with multiple thread + // Each thread's instance will have its own `DataFusionGlobalState` + // It will store global states including: + // JDBC connection to DataFusion server + // Logger for this thread + public static class DataFusionGlobalState extends SQLGlobalState { + public DataFusionLogger dfLogger; + DataFusionInstanceID id; + + public DataFusionGlobalState() { + // HACK: test will only run in spawned thread, not main thread + // this way redundant logger files won't be created + if (Thread.currentThread().getName().equals("main")) { + return; + } + + id = new DataFusionInstanceID(Thread.currentThread().getName()); + try { + dfLogger = new DataFusionLogger(this, id); + } catch (Exception e) { + throw new IgnoreMeException(); + } + } + + @Override + protected DataFusionSchema readSchema() throws SQLException { + return DataFusionSchema.fromConnection(getConnection(), getDatabaseName()); + } + } +} diff --git a/src/sqlancer/datafusion/DataFusionSchema.java b/src/sqlancer/datafusion/DataFusionSchema.java new file mode 100644 index 000000000..d02e80c30 --- /dev/null +++ b/src/sqlancer/datafusion/DataFusionSchema.java @@ -0,0 +1,195 @@ +package sqlancer.datafusion; + +import static sqlancer.datafusion.DataFusionUtil.dfAssert; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.common.schema.AbstractRelationalTable; +import sqlancer.common.schema.AbstractSchema; +import sqlancer.common.schema.AbstractTable; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.common.schema.TableIndex; +import sqlancer.datafusion.DataFusionProvider.DataFusionGlobalState; +import sqlancer.datafusion.DataFusionSchema.DataFusionTable; +import sqlancer.datafusion.ast.DataFusionConstant; +import sqlancer.datafusion.ast.DataFusionExpression; + +public class DataFusionSchema extends AbstractSchema { + + public DataFusionSchema(List databaseTables) { + super(databaseTables); + } + + // update existing tables in DB by query again + // (like `show tables;`) + public static DataFusionSchema fromConnection(SQLConnection con, String databaseName) throws SQLException { + List databaseTables = new ArrayList<>(); + List tableNames = getTableNames(con); + + for (String tableName : tableNames) { + List databaseColumns = getTableColumns(con, tableName); + boolean isView = matchesViewName(tableName); + DataFusionTable t = new DataFusionTable(tableName, databaseColumns, isView); + for (DataFusionColumn c : databaseColumns) { + c.setTable(t); + } + + databaseTables.add(t); + } + + return new DataFusionSchema(databaseTables); + } + + private static List getTableNames(SQLConnection con) throws SQLException { + List tableNames = new ArrayList<>(); + try (Statement s = con.createStatement()) { + try (ResultSet rs = s.executeQuery("select table_name " + "from information_schema.tables " + + "where table_schema='public'" + "order by table_name; ")) { + while (rs.next()) { + tableNames.add(rs.getString(1)); + } + } + } + return tableNames; + } + + private static List getTableColumns(SQLConnection con, String tableName) throws SQLException { + List columns = new ArrayList<>(); + try (Statement s = con.createStatement()) { + try (ResultSet rs = s.executeQuery( + String.format("select * from information_schema.columns where table_name = '%s';", tableName))) { + while (rs.next()) { + String columnName = rs.getString("column_name"); + String dataType = rs.getString("data_type"); + boolean isNullable = rs.getString("is_nullable").contentEquals("YES"); + + DataFusionColumn c = new DataFusionColumn(columnName, + DataFusionDataType.parseFromDataFusionCatalog(dataType), isNullable); + columns.add(c); + } + } + } + + return columns; + } + + /* + * When adding a new type: 1. Update all methods inside this enum 2. Update all `DataFusionBaseExpr`'s signature, if + * it can support new type (in `DataFusionBaseExprFactory.java` + * + * Types are 'SQL DataType' in DataFusion's documentation + * https://datafusion.apache.org/user-guide/sql/data_types.html + */ + public enum DataFusionDataType { + + BIGINT, DOUBLE, BOOLEAN, NULL; + + public static DataFusionDataType getRandomWithoutNull() { + DataFusionDataType dt; + do { + dt = Randomly.fromOptions(values()); + } while (dt == DataFusionDataType.NULL); + return dt; + } + + // How to parse type in DataFusion's catalog to `DataFusionDataType` + // As displayed in: + // create table t1(v1 int, v2 bigint); + // select table_name, column_name, data_type from information_schema.columns; + public static DataFusionDataType parseFromDataFusionCatalog(String typeString) { + switch (typeString) { + case "Int64": + return DataFusionDataType.BIGINT; + case "Float64": + return DataFusionDataType.DOUBLE; + case "Boolean": + return DataFusionDataType.BOOLEAN; + default: + dfAssert(false, "Unreachable. All branches should be eovered"); + } + + dfAssert(false, "Unreachable. All branches should be eovered"); + return null; + } + + // TODO(datafusion) lots of hack here, should build our own Randomly later + public DataFusionExpression getRandomConstant(DataFusionGlobalState state) { + if (Randomly.getBooleanWithSmallProbability()) { + return DataFusionConstant.createNullConstant(); + } + switch (this) { + case BIGINT: + return DataFusionConstant.createIntConstant(state.getRandomly().getInteger()); + case BOOLEAN: + return new DataFusionConstant.DataFusionBooleanConstant(Randomly.getBoolean()); + case DOUBLE: + if (Randomly.getBoolean()) { + if (Randomly.getBoolean()) { + Double randomDouble = state.getRandomly().getDouble(); // [0.0, 1.0); + Double scaledDouble = (randomDouble - 0.5) * 2 * Double.MAX_VALUE; + return new DataFusionConstant.DataFusionDoubleConstant(scaledDouble); + } + String doubleStr = Randomly.fromOptions("'NaN'::Double", "'+Inf'::Double", "'-Inf'::Double", "-0.0", + "+0.0"); + return new DataFusionConstant.DataFusionDoubleConstant(doubleStr); + } + + return new DataFusionConstant.DataFusionDoubleConstant(state.getRandomly().getDouble()); + case NULL: + return DataFusionConstant.createNullConstant(); + default: + dfAssert(false, "Unreachable. All branches should be eovered"); + } + + dfAssert(false, "Unreachable. All branches should be eovered"); + return DataFusionConstant.createNullConstant(); + } + } + + public static class DataFusionColumn extends AbstractTableColumn { + + private final boolean isNullable; + + public DataFusionColumn(String name, DataFusionDataType columnType, boolean isNullable) { + super(name, null, columnType); + this.isNullable = isNullable; + } + + public boolean isNullable() { + return isNullable; + } + + } + + public static class DataFusionTable + extends AbstractRelationalTable { + + public DataFusionTable(String tableName, List columns, boolean isView) { + super(tableName, columns, Collections.emptyList(), isView); + } + + public static List getAllColumns(List tables) { + return tables.stream().map(AbstractTable::getColumns).flatMap(List::stream).collect(Collectors.toList()); + } + + public static List getRandomColumns(List tables) { + if (Randomly.getBooleanWithRatherLowProbability()) { + return Arrays.asList(new DataFusionColumn("*", DataFusionDataType.NULL, true)); + } + + List allColumns = getAllColumns(tables); + + return Randomly.nonEmptySubset(allColumns); + } + } + +} diff --git a/src/sqlancer/datafusion/DataFusionToStringVisitor.java b/src/sqlancer/datafusion/DataFusionToStringVisitor.java new file mode 100644 index 000000000..7d0d1b1b6 --- /dev/null +++ b/src/sqlancer/datafusion/DataFusionToStringVisitor.java @@ -0,0 +1,98 @@ +package sqlancer.datafusion; + +import java.util.List; + +import sqlancer.common.ast.newast.NewToStringVisitor; +import sqlancer.datafusion.ast.DataFusionConstant; +import sqlancer.datafusion.ast.DataFusionExpression; +import sqlancer.datafusion.ast.DataFusionJoin; +import sqlancer.datafusion.ast.DataFusionSelect; + +public class DataFusionToStringVisitor extends NewToStringVisitor { + + public static String asString(DataFusionExpression expr) { + DataFusionToStringVisitor visitor = new DataFusionToStringVisitor(); + visitor.visit(expr); + return visitor.get(); + } + + public static String asString(List exprs) { + DataFusionToStringVisitor visitor = new DataFusionToStringVisitor(); + visitor.visit(exprs); + return visitor.get(); + } + + @Override + public void visitSpecific(DataFusionExpression expr) { + if (expr instanceof DataFusionConstant) { + visit((DataFusionConstant) expr); + } else if (expr instanceof DataFusionSelect) { + visit((DataFusionSelect) expr); + } else if (expr instanceof DataFusionJoin) { + visit((DataFusionJoin) expr); + } else { + throw new AssertionError(expr.getClass()); + } + } + + private void visit(DataFusionJoin join) { + visit((DataFusionExpression) join.getLeftTable()); + sb.append(" "); + sb.append(join.getJoinType()); + sb.append(" "); + + sb.append(" JOIN "); + visit((DataFusionExpression) join.getRightTable()); + if (join.getOnCondition() != null) { + sb.append(" ON "); + visit(join.getOnCondition()); + } + } + + private void visit(DataFusionConstant constant) { + sb.append(constant.toString()); + } + + private void visit(DataFusionSelect select) { + sb.append("SELECT "); + if (select.fetchColumnsString.isPresent()) { + sb.append(select.fetchColumnsString.get()); + } else { + visit(select.getFetchColumns()); + } + + sb.append(" FROM "); + visit(select.getFromList()); + if (!select.getFromList().isEmpty() && !select.getJoinList().isEmpty()) { + sb.append(", "); + } + if (!select.getJoinList().isEmpty()) { + visit(select.getJoinList()); + } + if (select.getWhereClause() != null) { + sb.append(" WHERE "); + visit(select.getWhereClause()); + } + if (!select.getGroupByExpressions().isEmpty()) { + sb.append(" GROUP BY "); + visit(select.getGroupByExpressions()); + } + if (select.getHavingClause() != null) { + sb.append(" HAVING "); + visit(select.getHavingClause()); + } + if (!select.getOrderByClauses().isEmpty()) { + sb.append(" ORDER BY "); + visit(select.getOrderByClauses()); + } + if (select.getLimitClause() != null) { + sb.append(" LIMIT "); + visit(select.getLimitClause()); + } + if (select.getOffsetClause() != null) { + sb.append(" OFFSET "); + visit(select.getOffsetClause()); + } + } + +} diff --git a/src/sqlancer/datafusion/DataFusionUtil.java b/src/sqlancer/datafusion/DataFusionUtil.java new file mode 100644 index 000000000..8761bec9b --- /dev/null +++ b/src/sqlancer/datafusion/DataFusionUtil.java @@ -0,0 +1,190 @@ +package sqlancer.datafusion; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileReader; +import java.io.FileWriter; +import java.io.IOException; +import java.nio.file.Paths; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.sql.Statement; +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; +import java.util.List; + +import sqlancer.datafusion.DataFusionProvider.DataFusionGlobalState; + +public final class DataFusionUtil { + private DataFusionUtil() { + dfAssert(false, "Utility class cannot be instantiated"); + } + + // Display tables in `fromTableNames` + public static String displayTables(DataFusionGlobalState state, List fromTableNames) { + StringBuilder resultStringBuilder = new StringBuilder(); + for (String tableName : fromTableNames) { + String query = String.format("select * from %s", tableName); + try (Statement stat = state.getConnection().createStatement(); + ResultSet wholeTable = stat.executeQuery(query)) { + + ResultSetMetaData metaData = wholeTable.getMetaData(); + int columnCount = metaData.getColumnCount(); + + resultStringBuilder.append("Table: ").append(tableName).append("\n"); + for (int i = 1; i <= columnCount; i++) { + resultStringBuilder.append(metaData.getColumnName(i)).append(" (") + .append(metaData.getColumnTypeName(i)).append(")"); + if (i < columnCount) { + resultStringBuilder.append(", "); + } + } + resultStringBuilder.append("\n"); + + while (wholeTable.next()) { + for (int i = 1; i <= columnCount; i++) { + resultStringBuilder.append(wholeTable.getString(i)); + if (i < columnCount) { + resultStringBuilder.append(", "); + } + } + resultStringBuilder.append("\n"); + } + resultStringBuilder.append("----------------------------------------\n\n"); + + } catch (SQLException err) { + resultStringBuilder.append("Table: ").append(tableName).append("\n"); + resultStringBuilder.append("----------------------------------------\n\n"); + // resultStringBuilder.append("Error retrieving data from table ").append(tableName).append(": + // ").append(err.getMessage()).append("\n"); + } + } + + return resultStringBuilder.toString(); + } + + // During development, you might want to manually let this function call exit(1) to fail fast + public static void dfAssert(boolean condition, String message) { + if (!condition) { + // // Development mode assertion failure + // String methodName = Thread.currentThread().getStackTrace()[2]// .getMethodName(); + // System.err.println("DataFusion assertion failed in function '" + methodName + "': " + message); + // exit(1); + + throw new AssertionError(message); + } + } + + /* + * Fetch all DMLs from logs/database*-cur.log + */ + public static String getReplay(String dbname) { + String path = "./logs/datafusion/" + dbname + "-cur.log"; + String absolutePath = Paths.get(path).toAbsolutePath().toString(); + + StringBuilder reproducer = new StringBuilder(); + + try (BufferedReader reader = new BufferedReader(new FileReader(absolutePath))) { + String line; + while ((line = reader.readLine()) != null) { + // Check if the line contains the /*DML*/ marker + if (line.contains("/*DML*/")) { + reproducer.append(line).append("\n"); + } + } + } catch (IOException e) { + System.err.println("Error reading from file: " + e.getMessage()); + } + + return reproducer.toString(); + } + + // UID for different fuzzer runs + public static class DataFusionInstanceID { + private final String id; + + public DataFusionInstanceID(String dfID) { + id = dfID; + } + + @Override + public String toString() { + return id; // Return the id field when toString is called + } + } + + /* + * Extra logs stored in 'logs/datafusion_custom_log/' In case re-run overwrite previous logs + */ + public static class DataFusionLogger { + private final DataFusionInstanceID dfID; + private final DataFusionGlobalState state; + /* + * Log file handles + */ + private final File errorLogFile; + + public DataFusionLogger(DataFusionGlobalState globalState, DataFusionInstanceID id) throws Exception { + this.state = globalState; + this.dfID = id; + + // Setup datafusion_custom_log folder + File baseDir = new File("logs/datafusion_custom_log/"); + if (!baseDir.exists() && !baseDir.mkdirs()) { + throw new IOException("Failed to create 'datafusion_custom_log' directory/"); + } + + // Setup error.log + errorLogFile = new File(baseDir, "error_report.log"); + errorLogFile.createNewFile(); + } + + // Caller is responsible for adding '\n' at the end of logContent + public void appendToLog(DataFusionLogType logType, String logContent) { + FileWriter logFileWriter = null; + + // Determine which log file to use based on the LogType + String logLineHeader = ""; + switch (logType) { + case ERROR: + try { + logFileWriter = new FileWriter(errorLogFile, true); + } catch (IOException e) { + dfAssert(false, "Failed to create FileWriter for errorLogFIle"); + } + DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"); + String formattedDateTime = LocalDateTime.now().format(formatter); + logLineHeader = "Run@" + formattedDateTime + " (" + dfID + ")\n"; + break; + case DML: + logFileWriter = state.getLogger().getCurrentFileWriter(); + logLineHeader = "/*DML*/"; + break; + case SELECT: + logFileWriter = state.getLogger().getCurrentFileWriter(); + break; + default: + dfAssert(false, "All branch should be covered"); + } + + // Append content to the appropriate log file + if (logFileWriter != null) { + try { + logFileWriter.write(logLineHeader); + logFileWriter.write(logContent); + logFileWriter.flush(); + } catch (IOException e) { + String err = "Failed to write to " + logType + " log: " + e.getMessage(); + dfAssert(false, err); + } + } else { + dfAssert(false, "appending to log failed"); + } + } + + public enum DataFusionLogType { + ERROR, DML, SELECT + } + } +} diff --git a/src/sqlancer/datafusion/ast/DataFusionBinaryOperation.java b/src/sqlancer/datafusion/ast/DataFusionBinaryOperation.java new file mode 100644 index 000000000..e59676be8 --- /dev/null +++ b/src/sqlancer/datafusion/ast/DataFusionBinaryOperation.java @@ -0,0 +1,11 @@ +package sqlancer.datafusion.ast; + +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.common.ast.newast.NewBinaryOperatorNode; + +public class DataFusionBinaryOperation extends NewBinaryOperatorNode + implements DataFusionExpression { + public DataFusionBinaryOperation(DataFusionExpression left, DataFusionExpression right, Operator op) { + super(left, right, op); + } +} diff --git a/src/sqlancer/datafusion/ast/DataFusionColumnReference.java b/src/sqlancer/datafusion/ast/DataFusionColumnReference.java new file mode 100644 index 000000000..2391ef694 --- /dev/null +++ b/src/sqlancer/datafusion/ast/DataFusionColumnReference.java @@ -0,0 +1,12 @@ +package sqlancer.datafusion.ast; + +import sqlancer.common.ast.newast.ColumnReferenceNode; +import sqlancer.datafusion.DataFusionSchema; + +public class DataFusionColumnReference extends + ColumnReferenceNode implements DataFusionExpression { + public DataFusionColumnReference(DataFusionSchema.DataFusionColumn column) { + super(column); + } + +} diff --git a/src/sqlancer/datafusion/ast/DataFusionConstant.java b/src/sqlancer/datafusion/ast/DataFusionConstant.java new file mode 100644 index 000000000..90a997fcf --- /dev/null +++ b/src/sqlancer/datafusion/ast/DataFusionConstant.java @@ -0,0 +1,97 @@ +package sqlancer.datafusion.ast; + +public class DataFusionConstant implements DataFusionExpression { + + private DataFusionConstant() { + } + + public static DataFusionExpression createIntConstant(long val) { + return new DataFusionIntConstant(val); + } + + public static DataFusionExpression createNullConstant() { + return new DataFusionNullConstant(); + } + + public static class DataFusionNullConstant extends DataFusionConstant { + + @Override + public String toString() { + return "NULL"; + } + + } + + public static class DataFusionIntConstant extends DataFusionConstant { + + private final long value; + + public DataFusionIntConstant(long value) { + this.value = value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + + public long getValue() { + return value; + } + + } + + public static class DataFusionDoubleConstant extends DataFusionConstant { + + private final String valueStr; + + public DataFusionDoubleConstant(double value) { + if (value == Double.POSITIVE_INFINITY) { + valueStr = "'+Inf'::Double"; + } else if (value == Double.NEGATIVE_INFINITY) { + valueStr = "'-Inf'::Double"; + } else if (Double.isNaN(value)) { + valueStr = "'NaN'::Double"; + } else if (Double.compare(value, -0.0) == 0) { + valueStr = "-0.0"; + } else { + valueStr = String.valueOf(value); + } + } + + // Make it more convenient to construct special value like -0, NaN, etc. + public DataFusionDoubleConstant(String valueStr) { + this.valueStr = valueStr; + } + + @Override + public String toString() { + return valueStr; + } + + } + + public static class DataFusionBooleanConstant extends DataFusionConstant { + + private final boolean value; + + public DataFusionBooleanConstant(boolean value) { + this.value = value; + } + + public boolean getValue() { + return value; + } + + @Override + public String toString() { + if (value) { + return "true"; + } else { + return "false"; + } + } + + } + +} diff --git a/src/sqlancer/datafusion/ast/DataFusionExpression.java b/src/sqlancer/datafusion/ast/DataFusionExpression.java new file mode 100644 index 000000000..eaf84ba1d --- /dev/null +++ b/src/sqlancer/datafusion/ast/DataFusionExpression.java @@ -0,0 +1,8 @@ +package sqlancer.datafusion.ast; + +import sqlancer.common.ast.newast.Expression; +import sqlancer.datafusion.DataFusionSchema.DataFusionColumn; + +public interface DataFusionExpression extends Expression { + +} diff --git a/src/sqlancer/datafusion/ast/DataFusionFunction.java b/src/sqlancer/datafusion/ast/DataFusionFunction.java new file mode 100644 index 000000000..130fc04bf --- /dev/null +++ b/src/sqlancer/datafusion/ast/DataFusionFunction.java @@ -0,0 +1,11 @@ +package sqlancer.datafusion.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewFunctionNode; + +public class DataFusionFunction extends NewFunctionNode implements DataFusionExpression { + public DataFusionFunction(List args, F func) { + super(args, func); + } +} diff --git a/src/sqlancer/datafusion/ast/DataFusionJoin.java b/src/sqlancer/datafusion/ast/DataFusionJoin.java new file mode 100644 index 000000000..05718c5ad --- /dev/null +++ b/src/sqlancer/datafusion/ast/DataFusionJoin.java @@ -0,0 +1,91 @@ +package sqlancer.datafusion.ast; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.ast.newast.Join; +import sqlancer.datafusion.DataFusionProvider.DataFusionGlobalState; +import sqlancer.datafusion.DataFusionSchema; +import sqlancer.datafusion.DataFusionSchema.DataFusionColumn; +import sqlancer.datafusion.DataFusionSchema.DataFusionTable; +import sqlancer.datafusion.gen.DataFusionExpressionGenerator; + +/* + NOT IMPLEMENTED YET + */ +public class DataFusionJoin + implements DataFusionExpression, Join { + + private final DataFusionTableReference leftTable; + private final DataFusionTableReference rightTable; + private final JoinType joinType; + private DataFusionExpression onCondition; + + public DataFusionJoin(DataFusionTableReference leftTable, DataFusionTableReference rightTable, JoinType joinType, + DataFusionExpression whereCondition) { + this.leftTable = leftTable; + this.rightTable = rightTable; + this.joinType = joinType; + this.onCondition = whereCondition; + } + + public static List getJoins(List tableList, + DataFusionGlobalState globalState) { + // [t1_join_t2, t1_join_t3, ...] + List joinExpressions = new ArrayList<>(); + while (tableList.size() >= 2 && Randomly.getBooleanWithRatherLowProbability()) { + DataFusionTableReference leftTable = tableList.remove(0); + DataFusionTableReference rightTable = tableList.remove(0); + List columns = new ArrayList<>(leftTable.getTable().getColumns()); + columns.addAll(rightTable.getTable().getColumns()); + // TODO(datafusion) this `joinGen` can generate super chaotic exprsions, maybe we should make it more like a + // normal join expression + DataFusionExpressionGenerator joinGen = new DataFusionExpressionGenerator(globalState).setColumns(columns); + switch (DataFusionJoin.JoinType.getRandom()) { + case INNER: + joinExpressions.add(DataFusionJoin.createInnerJoin(leftTable, rightTable, + joinGen.generateExpression(DataFusionSchema.DataFusionDataType.BOOLEAN))); + break; + default: + throw new AssertionError(); + } + } + return joinExpressions; + } + + public static DataFusionJoin createInnerJoin(DataFusionTableReference left, DataFusionTableReference right, + DataFusionExpression predicate) { + return new DataFusionJoin(left, right, JoinType.INNER, predicate); + } + + public DataFusionTableReference getLeftTable() { + return leftTable; + } + + public DataFusionTableReference getRightTable() { + return rightTable; + } + + public JoinType getJoinType() { + return joinType; + } + + public DataFusionExpression getOnCondition() { + return onCondition; + } + + public enum JoinType { + INNER; + // NATURAL, LEFT, RIGHT; + + public static JoinType getRandom() { + return Randomly.fromOptions(values()); + } + } + + @Override + public void setOnClause(DataFusionExpression onClause) { + onCondition = onClause; + } +} diff --git a/src/sqlancer/datafusion/ast/DataFusionSelect.java b/src/sqlancer/datafusion/ast/DataFusionSelect.java new file mode 100644 index 000000000..75bbb26ae --- /dev/null +++ b/src/sqlancer/datafusion/ast/DataFusionSelect.java @@ -0,0 +1,43 @@ +package sqlancer.datafusion.ast; + +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +import sqlancer.common.ast.SelectBase; +import sqlancer.common.ast.newast.Select; +import sqlancer.datafusion.DataFusionSchema.DataFusionColumn; +import sqlancer.datafusion.DataFusionSchema.DataFusionTable; +import sqlancer.datafusion.DataFusionToStringVisitor; + +public class DataFusionSelect extends SelectBase implements DataFusionExpression, + Select { + public Optional fetchColumnsString = Optional.empty(); // When available, override `fetchColumns` in base + + /* + * If set fetch columns with string It will override `fetchColumns` in base class when + * `DataFusionToStringVisitor.asString()` is called + * + * This method can be helpful to mutate select in oracle checks: SELECT [expr] ... -> SELECT SUM[expr] + */ + public void setFetchColumnsString(String selectExpr) { + this.fetchColumnsString = Optional.of(selectExpr); + } + + @Override + public void setJoinClauses(List joinStatements) { + List expressions = joinStatements.stream().map(e -> (DataFusionExpression) e) + .collect(Collectors.toList()); + setJoinList(expressions); + } + + @Override + public List getJoinClauses() { + return getJoinList().stream().map(e -> (DataFusionJoin) e).collect(Collectors.toList()); + } + + @Override + public String asString() { + return DataFusionToStringVisitor.asString(this); + } +} diff --git a/src/sqlancer/datafusion/ast/DataFusionTableReference.java b/src/sqlancer/datafusion/ast/DataFusionTableReference.java new file mode 100644 index 000000000..14445576a --- /dev/null +++ b/src/sqlancer/datafusion/ast/DataFusionTableReference.java @@ -0,0 +1,11 @@ +package sqlancer.datafusion.ast; + +import sqlancer.common.ast.newast.TableReferenceNode; +import sqlancer.datafusion.DataFusionSchema; + +public class DataFusionTableReference extends TableReferenceNode + implements DataFusionExpression { + public DataFusionTableReference(DataFusionSchema.DataFusionTable table) { + super(table); + } +} diff --git a/src/sqlancer/datafusion/ast/DataFusionUnaryPostfixOperation.java b/src/sqlancer/datafusion/ast/DataFusionUnaryPostfixOperation.java new file mode 100644 index 000000000..ba5629460 --- /dev/null +++ b/src/sqlancer/datafusion/ast/DataFusionUnaryPostfixOperation.java @@ -0,0 +1,11 @@ +package sqlancer.datafusion.ast; + +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.newast.NewUnaryPostfixOperatorNode; + +public class DataFusionUnaryPostfixOperation extends NewUnaryPostfixOperatorNode + implements DataFusionExpression { + public DataFusionUnaryPostfixOperation(DataFusionExpression expr, BinaryOperatorNode.Operator op) { + super(expr, op); + } +} diff --git a/src/sqlancer/datafusion/ast/DataFusionUnaryPrefixOperation.java b/src/sqlancer/datafusion/ast/DataFusionUnaryPrefixOperation.java new file mode 100644 index 000000000..7109a12c2 --- /dev/null +++ b/src/sqlancer/datafusion/ast/DataFusionUnaryPrefixOperation.java @@ -0,0 +1,11 @@ +package sqlancer.datafusion.ast; + +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode; + +public class DataFusionUnaryPrefixOperation extends NewUnaryPrefixOperatorNode + implements DataFusionExpression { + public DataFusionUnaryPrefixOperation(DataFusionExpression expr, BinaryOperatorNode.Operator operator) { + super(expr, operator); + } +} diff --git a/src/sqlancer/datafusion/gen/DataFusionBaseExpr.java b/src/sqlancer/datafusion/gen/DataFusionBaseExpr.java new file mode 100644 index 000000000..0be57486e --- /dev/null +++ b/src/sqlancer/datafusion/gen/DataFusionBaseExpr.java @@ -0,0 +1,258 @@ +package sqlancer.datafusion.gen; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.datafusion.DataFusionSchema.DataFusionDataType; + +/* + Notes for adding new `DataFusionBaseExpr` support: + + Expression ASTs are constructed with Node<> class, it can traverse expression and convert AST to String representation + `DataFusionBaseExpr` implements `Operator<>` class, which is a field inside `Node<>` class, it includes operator properties like number of arguments, signature, or is this operator prefix/suffix, etc. + + To add new base expr (scalar functions, operators like '<<', 'AND' are all base expr): + 1. Add an enum variant to `DataFusionBaseExprType` + 2. Update `DataFusionBaseExprFactory.java` + (If a function support different argument number, make a new entry for each one. e.g. round scalar function support round(3.14) / round(3.14, 1), so it should be enum FUNC_ROUND1, FUNC_ROUND2) + */ +public class DataFusionBaseExpr implements Operator { + public String name; + public int nArgs; // number of input arguments + public DataFusionBaseExprCategory exprType; + public List possibleReturnTypes; + public List argTypes; + public boolean isVariadic; // Function supports arbitrary number of arguments, if set to `true`, it will + // override `nArgs` + + // Primary constructor + DataFusionBaseExpr(String name, int nArgs, DataFusionBaseExprCategory exprCategory, + List possibleReturnTypes, List argTypes, boolean isVariadic) { + this.name = name; + this.nArgs = nArgs; + this.exprType = exprCategory; + this.possibleReturnTypes = possibleReturnTypes; + this.argTypes = argTypes; + this.isVariadic = isVariadic; + } + + // Overloaded constructor assuming 'isVariadic' is false + DataFusionBaseExpr(String name, int nArgs, DataFusionBaseExprCategory exprCategory, + List possibleReturnTypes, List argTypes) { + this(name, nArgs, exprCategory, possibleReturnTypes, argTypes, false); + } + + public static DataFusionBaseExpr createCommonNumericFuncSingleArg(String name) { + return new DataFusionBaseExpr(name, 1, DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), + Arrays.asList(new ArgumentType.Fixed( + new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))))); + } + + public static DataFusionBaseExpr createCommonNumericAggrFuncSingleArg(String name) { + return new DataFusionBaseExpr(name, 1, DataFusionBaseExprCategory.AGGREGATE, + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), + Arrays.asList(new ArgumentType.Fixed( + new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))))); + } + + public static DataFusionBaseExpr createCommonNumericFuncTwoArgs(String name) { + return new DataFusionBaseExpr(name, 2, DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), + Arrays.asList( + new ArgumentType.Fixed( + new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), + new ArgumentType.Fixed( + new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))))); + } + + @Override + public String getTextRepresentation() { + return name; + } + + @Override + public String toString() { + return name; + } + + /* + * Class/Enum for `DataFusionBaseExpr` fields + */ + // Used to construct `src.common.ast.*Node` + public enum DataFusionBaseExprCategory { + UNARY_PREFIX, UNARY_POSTFIX, BINARY, FUNC, AGGREGATE + } + + /* + * Operators reference: https://datafusion.apache.org/user-guide/sql/operators.html Scalar functions: + * https://datafusion.apache.org/user-guide/sql/scalar_functions.html + */ + public enum DataFusionBaseExprType { + // Null Operators + IS_NULL, // IS NULL + IS_NOT_NULL, // IS NOT NULL + + // Numeric Operators + ADD, // 1 + 1 + SUB, // 1 - 1 + MULTIPLICATION, // 2 * 3 + DIVISION, // 8 / 4 + MODULO, // 5 % 3 + + // Comparison Operators + EQUAL, // 1 = 1 + EQUAL2, // 1 == 1 + NOT_EQUAL, // 1 != 2 + LESS_THAN, // 3 < 4 + LESS_THAN_OR_EQUAL_TO, // 3 <= 3 + GREATER_THAN, // 6 > 5 + GREATER_THAN_OR_EQUAL_TO, // 5 >= 5 + + // Distinctness operators + IS_DISTINCT_FROM, // 0 IS DISTINCT FROM NULL + IS_NOT_DISTINCT_FROM, // NULL IS NOT DISTINCT FROM NULL + + /* + * // Regular expression match operators REGEX_MATCH, // 'datafusion' ~ '^datafusion(-cli)*' + * REGEX_CASE_INSENSITIVE_MATCH, // 'datafusion' ~* '^DATAFUSION(-cli)*' NOT_REGEX_MATCH, // 'datafusion' !~ + * '^DATAFUSION(-cli)*' NOT_REGEX_CASE_INSENSITIVE_MATCH, // 'datafusion' !~* '^DATAFUSION(-cli)+' + * + * // Like pattern match operators LIKE_MATCH, // 'datafusion' ~~ 'dat_f%n' CASE_INSENSITIVE_LIKE_MATCH, // + * 'datafusion' ~~* 'Dat_F%n' NOT_LIKE_MATCH, // 'datafusion' !~~ 'Dat_F%n' NOT_CASE_INSENSITIVE_LIKE_MATCH // + * 'datafusion' !~~* 'Dat%F_n' + */ + + // Logical Operators + AND, // true and true + OR, // true or false + + // Bitwise Operators + BITWISE_AND, // 5 & 3 + BITWISE_OR, // 5 | 3 + BITWISE_XOR, // 5 ^ 3 + BITWISE_SHIFT_RIGHT, // 5 >> 3 + BITWISE_SHIFT_LEFT, // 5 << 3 + + /* + * // Other operators STRING_CONCATENATION, // 'Hello, ' || 'DataFusion!' ARRAY_CONTAINS, // + * make_array(1,2,3) @> make_array(1,3) ARRAY_IS_CONTAINED_BY // make_array(1,3) <@ make_array(1,2,3) + */ + + // Unary Prefix Operators + NOT, // NOT true + PLUS, // +7 + MINUS, // -3 + + /* + * Scalar Functions + */ + + // Math Functions + FUNC_ABS, // abs(-10) + FUNC_ACOS, // acos(1) + FUNC_ACOSH, // acosh(10) + FUNC_ASIN, // asin(1) + FUNC_ASINH, // asinh(1) + FUNC_ATAN, // atan(1) + FUNC_ATANH, // atanh(0.5) + FUNC_ATAN2, // atan2(10, 10) + FUNC_CBRT, // cbrt(27) + FUNC_CEIL, // ceil(9.2) + FUNC_COS, // cos(π/3) + FUNC_COSH, // cosh(0) + FUNC_DEGREES, // degrees(π) + FUNC_EXP, // exp(1) + FUNC_FACTORIAL, // factorial(5) + FUNC_FLOOR, // floor(3.7) + FUNC_GCD, // gcd(8, 12) + FUNC_ISNAN, // isnan(NaN) + FUNC_ISZERO, // iszero(0.0) + FUNC_LCM, // lcm(5, 15) + FUNC_LN, // ln(1) + FUNC_LOG, // log(100) + FUNC_LOG_WITH_BASE, // log(10, 100) + FUNC_LOG10, // log10(100) + FUNC_LOG2, // log2(32) + FUNC_NANVL, // nanvl(NaN, 3) + FUNC_PI, // pi() + FUNC_POW, // pow(2, 3) + FUNC_POWER, // power(2, 3) + FUNC_RADIANS, // radians(180) + // FUNC_RANDOM, // random() disabled because it's non-deterministic + FUNC_ROUND, // round(3.14159) + FUNC_ROUND_WITH_DECIMAL, // round(3.14159, 2) + FUNC_SIGNUM, // signum(-10) + FUNC_SIN, // sin(π/2) + FUNC_SINH, // sinh(1) + FUNC_SQRT, // sqrt(16) + FUNC_TAN, // tan(π/4) + FUNC_TANH, // tanh(1) + FUNC_TRUNC, // trunc(3.14159) + FUNC_TRUNC_WITH_DECIMAL, // trunc(3.14159, 2) + + // Conditional Functions + FUNC_COALESCE, // coalesce(NULL, 'default value') + FUNC_NULLIF, // nullif('value', 'value') + FUNC_NVL, // nvl(NULL, 'default value') + FUNC_NVL2, // nvl2('not null', 'return if not null', 'return if null') + FUNC_IFNULL, // ifnull(NULL, 'default value') + + // String Functions + + // Time and Date Functions + + // Array Functions + + // Struct Functions + + // Hashing Functions + + // Other Functions + + // Aggregate Functions + AGGR_MIN, AGGR_MAX, AGGR_SUM, AGGR_AVG, AGGR_COUNT, + } + + /* + * Because expressions are constructed in a top-down way, we have to infer argument type given return type. For each + * arg, if its corresponding element is `SameAsReturnType`, it should be the same as the type of expression's + * evaluated value. Else, it should be specific `DataFusionDataType` + * + * e.g. let's say we're generating a round(num, digit) of double type, its `argTypes` is: Arrays.asList( new + * ArgumentType.SameAsReturnType(), // First arg type as return type new ArgumentType.Fixed(new + * ArrayList<>(Array.asList(DataFusionDataType.INT)) // Second arg always Integer ) it means: its first argument + * should be the same as returned type (double), and the second arg should always be Int. + * + * Random expression generator's policy: SameAsReturnType -> generate an expr with the same type as its return type + * SameAsReturnType -> generate an expr with the same type as its 1st arg type Fixed(type1, type2, ... typeN) -> + * randomly choose a possible type (It will also generate completely random type/null ~10%) + * + * Note this defination is not comprehensive for native `DataFusion` types. It's just for simplicity and should + * cover most common cases + */ + public abstract static class ArgumentType { + private ArgumentType() { + } + + public static class SameAsReturnType extends ArgumentType { + } + + public static class SameAsFirstArgType extends ArgumentType { + } + + public static class Fixed extends ArgumentType { + public List fixedType; // It's a list to support different possible arg types. + + public Fixed(List fixedType) { + this.fixedType = fixedType; + } + + public List getType() { + return fixedType; + } + } + } +} diff --git a/src/sqlancer/datafusion/gen/DataFusionBaseExprFactory.java b/src/sqlancer/datafusion/gen/DataFusionBaseExprFactory.java new file mode 100644 index 000000000..d3fe39972 --- /dev/null +++ b/src/sqlancer/datafusion/gen/DataFusionBaseExprFactory.java @@ -0,0 +1,391 @@ +package sqlancer.datafusion.gen; + +import static sqlancer.datafusion.DataFusionUtil.dfAssert; +import static sqlancer.datafusion.gen.DataFusionBaseExpr.createCommonNumericAggrFuncSingleArg; +import static sqlancer.datafusion.gen.DataFusionBaseExpr.createCommonNumericFuncSingleArg; +import static sqlancer.datafusion.gen.DataFusionBaseExpr.createCommonNumericFuncTwoArgs; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.datafusion.DataFusionSchema.DataFusionDataType; +import sqlancer.datafusion.gen.DataFusionBaseExpr.ArgumentType; +import sqlancer.datafusion.gen.DataFusionBaseExpr.DataFusionBaseExprCategory; +import sqlancer.datafusion.gen.DataFusionBaseExpr.DataFusionBaseExprType; + +public final class DataFusionBaseExprFactory { + private DataFusionBaseExprFactory() { + dfAssert(false, "Utility class cannot be instantiated"); + } + + public static DataFusionBaseExpr createExpr(DataFusionBaseExprType type) { + switch (type) { + case IS_NULL: + return new DataFusionBaseExpr("IS NULL", 1, DataFusionBaseExprCategory.UNARY_POSTFIX, + Arrays.asList(DataFusionDataType.BOOLEAN), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN, + DataFusionDataType.DOUBLE, DataFusionDataType.BIGINT, DataFusionDataType.NULL))))); + case IS_NOT_NULL: + return new DataFusionBaseExpr("IS NOT NULL", 1, DataFusionBaseExprCategory.UNARY_POSTFIX, + Arrays.asList(DataFusionDataType.BOOLEAN), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN, + DataFusionDataType.DOUBLE, DataFusionDataType.BIGINT, DataFusionDataType.NULL))))); + case BITWISE_AND: + return new DataFusionBaseExpr("&", 2, DataFusionBaseExprCategory.BINARY, + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), + Arrays.asList( + new ArgumentType.Fixed(new ArrayList<>( + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), + new ArgumentType.Fixed(new ArrayList<>( + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))))); + case BITWISE_OR: + return new DataFusionBaseExpr("|", 2, DataFusionBaseExprCategory.BINARY, + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), + Arrays.asList( + new ArgumentType.Fixed(new ArrayList<>( + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), + new ArgumentType.Fixed(new ArrayList<>( + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))))); + case BITWISE_XOR: + return new DataFusionBaseExpr("^", 2, DataFusionBaseExprCategory.BINARY, + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), + Arrays.asList( + new ArgumentType.Fixed(new ArrayList<>( + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), + new ArgumentType.Fixed(new ArrayList<>( + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))))); + case BITWISE_SHIFT_RIGHT: + return new DataFusionBaseExpr(">>", 2, DataFusionBaseExprCategory.BINARY, + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), + Arrays.asList( + new ArgumentType.Fixed(new ArrayList<>( + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT))))); + case BITWISE_SHIFT_LEFT: + return new DataFusionBaseExpr("<<", 2, DataFusionBaseExprCategory.BINARY, + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), + Arrays.asList( + new ArgumentType.Fixed(new ArrayList<>( + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT))))); + case NOT: + return new DataFusionBaseExpr("NOT", 1, DataFusionBaseExprCategory.UNARY_PREFIX, + Arrays.asList(DataFusionDataType.BOOLEAN), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN))))); + case PLUS: // unary prefix '+' + return new DataFusionBaseExpr("+", 1, DataFusionBaseExprCategory.UNARY_PREFIX, + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), + Arrays.asList(new ArgumentType.Fixed( + new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))))); + case MINUS: // unary prefix '-' + return new DataFusionBaseExpr("-", 1, DataFusionBaseExprCategory.UNARY_PREFIX, + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), + Arrays.asList(new ArgumentType.Fixed( + new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))))); + case MULTIPLICATION: + return new DataFusionBaseExpr("*", 2, DataFusionBaseExprCategory.BINARY, + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), + Arrays.asList( + new ArgumentType.Fixed(new ArrayList<>( + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), + new ArgumentType.Fixed(new ArrayList<>( + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))))); + case DIVISION: + return new DataFusionBaseExpr("/", 2, DataFusionBaseExprCategory.BINARY, + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), + Arrays.asList( + new ArgumentType.Fixed(new ArrayList<>( + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), + new ArgumentType.Fixed(new ArrayList<>( + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))))); + case MODULO: + return new DataFusionBaseExpr("%", 2, DataFusionBaseExprCategory.BINARY, + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), + Arrays.asList( + new ArgumentType.Fixed(new ArrayList<>( + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), + new ArgumentType.Fixed(new ArrayList<>( + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))))); + case EQUAL: + return new DataFusionBaseExpr("=", 2, DataFusionBaseExprCategory.BINARY, + Arrays.asList(DataFusionDataType.BOOLEAN), + Arrays.asList( + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, + DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), + new ArgumentType.SameAsFirstArgType())); + case EQUAL2: + return new DataFusionBaseExpr("==", 2, DataFusionBaseExprCategory.BINARY, + Arrays.asList(DataFusionDataType.BOOLEAN), + Arrays.asList( + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, + DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), + new ArgumentType.SameAsFirstArgType())); + case NOT_EQUAL: + return new DataFusionBaseExpr("!=", 2, DataFusionBaseExprCategory.BINARY, + Arrays.asList(DataFusionDataType.BOOLEAN), + Arrays.asList( + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, + DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), + new ArgumentType.SameAsFirstArgType())); + case LESS_THAN: + return new DataFusionBaseExpr("<", 2, DataFusionBaseExprCategory.BINARY, + Arrays.asList(DataFusionDataType.BOOLEAN), + Arrays.asList( + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, + DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), + new ArgumentType.SameAsFirstArgType())); + case LESS_THAN_OR_EQUAL_TO: + return new DataFusionBaseExpr("<=", 2, DataFusionBaseExprCategory.BINARY, + Arrays.asList(DataFusionDataType.BOOLEAN), + Arrays.asList( + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, + DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), + new ArgumentType.SameAsFirstArgType())); + case GREATER_THAN: + return new DataFusionBaseExpr(">", 2, DataFusionBaseExprCategory.BINARY, + Arrays.asList(DataFusionDataType.BOOLEAN), + Arrays.asList( + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, + DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), + new ArgumentType.SameAsFirstArgType())); + case GREATER_THAN_OR_EQUAL_TO: + return new DataFusionBaseExpr(">=", 2, DataFusionBaseExprCategory.BINARY, + Arrays.asList(DataFusionDataType.BOOLEAN), + Arrays.asList( + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, + DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), + new ArgumentType.SameAsFirstArgType())); + case IS_DISTINCT_FROM: + return new DataFusionBaseExpr("IS DISTINCT FROM", 2, DataFusionBaseExprCategory.BINARY, + Arrays.asList(DataFusionDataType.BOOLEAN), + Arrays.asList( + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, + DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), + new ArgumentType.SameAsFirstArgType())); + case IS_NOT_DISTINCT_FROM: + return new DataFusionBaseExpr("IS NOT DISTINCT FROM", 2, DataFusionBaseExprCategory.BINARY, + Arrays.asList(DataFusionDataType.BOOLEAN), + Arrays.asList( + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT, + DataFusionDataType.DOUBLE, DataFusionDataType.BOOLEAN))), + new ArgumentType.SameAsFirstArgType())); + case AND: + return new DataFusionBaseExpr("AND", 2, DataFusionBaseExprCategory.BINARY, + Arrays.asList(DataFusionDataType.BOOLEAN), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN))), // arg1 + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN))) // arg2 + )); + case OR: + return new DataFusionBaseExpr("OR", 2, DataFusionBaseExprCategory.BINARY, + Arrays.asList(DataFusionDataType.BOOLEAN), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN))), // arg1 + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN))) // arg2 + )); + case ADD: // binary arithmetic operator '+' + return new DataFusionBaseExpr("+", 2, DataFusionBaseExprCategory.BINARY, + Arrays.asList(DataFusionDataType.BIGINT), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT))), // arg1 + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT))) // arg2 + )); + case SUB: // binary arithmetic operator '-' + return new DataFusionBaseExpr("-", 2, DataFusionBaseExprCategory.BINARY, + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), + Arrays.asList( + new ArgumentType.Fixed(new ArrayList<>( + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), // arg1 + new ArgumentType.Fixed(new ArrayList<>( + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))) // arg2 + )); + case FUNC_ABS: + return createCommonNumericFuncSingleArg("ABS"); + case FUNC_ACOS: + return createCommonNumericFuncSingleArg("ACOS"); + case FUNC_ACOSH: + return createCommonNumericFuncSingleArg("ACOSH"); + case FUNC_ASIN: + return createCommonNumericFuncSingleArg("ASIN"); + case FUNC_ASINH: + return createCommonNumericFuncSingleArg("ASINH"); + case FUNC_ATAN: + return createCommonNumericFuncSingleArg("ATAN"); + case FUNC_ATANH: + return createCommonNumericFuncSingleArg("ATANH"); + case FUNC_ATAN2: + return createCommonNumericFuncTwoArgs("ATAN2"); + case FUNC_CBRT: + return createCommonNumericFuncSingleArg("CBRT"); + case FUNC_CEIL: + return createCommonNumericFuncSingleArg("CEIL"); + case FUNC_COS: + return createCommonNumericFuncSingleArg("COS"); + case FUNC_COSH: + return createCommonNumericFuncSingleArg("COSH"); + case FUNC_DEGREES: + return createCommonNumericFuncSingleArg("DEGREES"); + case FUNC_EXP: + return createCommonNumericFuncSingleArg("EXP"); + case FUNC_FACTORIAL: + return createCommonNumericFuncSingleArg("FACTORIAL"); + case FUNC_FLOOR: + return createCommonNumericFuncSingleArg("FLOOR"); + case FUNC_GCD: + return new DataFusionBaseExpr("GCD", 2, DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), + Arrays.asList( + new ArgumentType.Fixed(new ArrayList<>( + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), + new ArgumentType.Fixed(new ArrayList<>( + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))))); + case FUNC_ISNAN: + return createCommonNumericFuncSingleArg("ISNAN"); + case FUNC_ISZERO: + return createCommonNumericFuncSingleArg("ISZERO"); + case FUNC_LCM: + return createCommonNumericFuncTwoArgs("LCM"); + case FUNC_LN: + return createCommonNumericFuncSingleArg("LN"); + case FUNC_LOG: + return createCommonNumericFuncSingleArg("LOG"); + case FUNC_LOG_WITH_BASE: + return createCommonNumericFuncTwoArgs("LOG"); + case FUNC_LOG10: + return createCommonNumericFuncSingleArg("LOG10"); + case FUNC_LOG2: + return createCommonNumericFuncSingleArg("LOG2"); + case FUNC_NANVL: + return createCommonNumericFuncTwoArgs("NANVL"); + case FUNC_PI: + return new DataFusionBaseExpr("PI", 0, DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), Arrays.asList()); + case FUNC_POW: + return createCommonNumericFuncSingleArg("POW"); + case FUNC_POWER: + return createCommonNumericFuncSingleArg("POWER"); + case FUNC_RADIANS: + return createCommonNumericFuncSingleArg("RADIANS"); + case FUNC_ROUND: + return createCommonNumericFuncSingleArg("ROUND"); + case FUNC_ROUND_WITH_DECIMAL: + return new DataFusionBaseExpr("ROUND", 2, DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), + Arrays.asList( + new ArgumentType.Fixed(new ArrayList<>( + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT))))); + case FUNC_SIGNUM: + return createCommonNumericFuncSingleArg("SIGNUM"); + case FUNC_SIN: + return createCommonNumericFuncSingleArg("SIN"); + case FUNC_SINH: + return createCommonNumericFuncSingleArg("SINH"); + case FUNC_SQRT: + return createCommonNumericFuncSingleArg("SQRT"); + case FUNC_TAN: + return createCommonNumericFuncSingleArg("TAN"); + case FUNC_TANH: + return createCommonNumericFuncSingleArg("TANH"); + case FUNC_TRUNC: + return createCommonNumericFuncSingleArg("TRUNC"); + case FUNC_TRUNC_WITH_DECIMAL: + return new DataFusionBaseExpr("TRUNC", 2, DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), + Arrays.asList( + new ArgumentType.Fixed(new ArrayList<>( + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BIGINT))))); + case FUNC_COALESCE: + return new DataFusionBaseExpr("COALESCE", -1, // overide by variadic + DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), Arrays.asList(), true); + case FUNC_NULLIF: + return new DataFusionBaseExpr("NULLIF", 2, DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), + Arrays.asList( + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN, + DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN, + DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))))); + case FUNC_NVL: + return new DataFusionBaseExpr("NVL", 2, DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), + Arrays.asList( + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN, + DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN, + DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))))); + case FUNC_NVL2: + return new DataFusionBaseExpr("NVL2", 3, DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), + Arrays.asList( + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN, + DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN, + DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN, + DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))))); + case FUNC_IFNULL: + return new DataFusionBaseExpr("IFNULL", 2, DataFusionBaseExprCategory.FUNC, + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), + Arrays.asList( + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN, + DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))), + new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN, + DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE))))); + + case AGGR_MIN: + return createCommonNumericAggrFuncSingleArg("MIN"); + case AGGR_MAX: + return createCommonNumericAggrFuncSingleArg("MAX"); + case AGGR_AVG: + return createCommonNumericAggrFuncSingleArg("AVG"); + case AGGR_SUM: + return createCommonNumericAggrFuncSingleArg("SUM"); + case AGGR_COUNT: + return new DataFusionBaseExpr("COUNT", -1, DataFusionBaseExprCategory.AGGREGATE, + Arrays.asList(DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE), + Arrays.asList(new ArgumentType.Fixed(new ArrayList<>(Arrays.asList(DataFusionDataType.BOOLEAN, + DataFusionDataType.BIGINT, DataFusionDataType.DOUBLE)))), + true); + default: + dfAssert(false, "Unreachable. Unimplemented branch for type " + type); + } + + dfAssert(false, "Unreachable. Unimplemented branch for type " + type); + return null; + } + + // if input is Optional.empty(), return all possible `DataFusionBaseExpr`s + // else, return all `DataFusionBaseExpr` which might be evaluated to arg's type + public static List getExprsWithReturnType(Optional dataTypeOptional) { + List allExpressions = Arrays.stream(DataFusionBaseExprType.values()) + .map(DataFusionBaseExprFactory::createExpr).collect(Collectors.toList()); + + if (!dataTypeOptional.isPresent()) { + return allExpressions; // If Optional is empty, return all expressions + } + + DataFusionDataType filterType = dataTypeOptional.get(); + List exprsWithReturnType = allExpressions.stream() + .filter(expr -> expr.possibleReturnTypes.contains(filterType)).collect(Collectors.toList()); + + if (Randomly.getBoolean()) { + // Too many similar function, so test them less often + return exprsWithReturnType; + } + + return exprsWithReturnType.stream().filter(expr -> expr.exprType != DataFusionBaseExprCategory.FUNC) + .collect(Collectors.toList()); + } + + public static DataFusionBaseExpr getRandomAggregateExpr() { + List allAggrExpressions = Arrays.stream(DataFusionBaseExprType.values()) + .map(DataFusionBaseExprFactory::createExpr) + .filter(expr -> expr.exprType == DataFusionBaseExprCategory.AGGREGATE).collect(Collectors.toList()); + + return Randomly.fromList(allAggrExpressions); + } +} diff --git a/src/sqlancer/datafusion/gen/DataFusionExpressionGenerator.java b/src/sqlancer/datafusion/gen/DataFusionExpressionGenerator.java new file mode 100644 index 000000000..520043fa1 --- /dev/null +++ b/src/sqlancer/datafusion/gen/DataFusionExpressionGenerator.java @@ -0,0 +1,305 @@ +package sqlancer.datafusion.gen; + +import static sqlancer.datafusion.DataFusionUtil.dfAssert; +import static sqlancer.datafusion.gen.DataFusionBaseExprFactory.createExpr; +import static sqlancer.datafusion.gen.DataFusionBaseExprFactory.getExprsWithReturnType; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.common.ast.newast.NewUnaryPostfixOperatorNode; +import sqlancer.common.gen.NoRECGenerator; +import sqlancer.common.gen.TLPWhereGenerator; +import sqlancer.common.gen.TypedExpressionGenerator; +import sqlancer.common.schema.AbstractTables; +import sqlancer.datafusion.DataFusionProvider.DataFusionGlobalState; +import sqlancer.datafusion.DataFusionSchema.DataFusionColumn; +import sqlancer.datafusion.DataFusionSchema.DataFusionDataType; +import sqlancer.datafusion.DataFusionSchema.DataFusionTable; +import sqlancer.datafusion.DataFusionToStringVisitor; +import sqlancer.datafusion.ast.DataFusionBinaryOperation; +import sqlancer.datafusion.ast.DataFusionColumnReference; +import sqlancer.datafusion.ast.DataFusionExpression; +import sqlancer.datafusion.ast.DataFusionFunction; +import sqlancer.datafusion.ast.DataFusionJoin; +import sqlancer.datafusion.ast.DataFusionSelect; +import sqlancer.datafusion.ast.DataFusionTableReference; +import sqlancer.datafusion.ast.DataFusionUnaryPostfixOperation; +import sqlancer.datafusion.ast.DataFusionUnaryPrefixOperation; +import sqlancer.datafusion.gen.DataFusionBaseExpr.ArgumentType; +import sqlancer.datafusion.gen.DataFusionBaseExpr.DataFusionBaseExprType; + +public final class DataFusionExpressionGenerator + extends TypedExpressionGenerator implements + NoRECGenerator, + TLPWhereGenerator { + + private List tables; + private final DataFusionGlobalState globalState; + + public DataFusionExpressionGenerator(DataFusionGlobalState globalState) { + this.globalState = globalState; + } + + @Override + protected DataFusionDataType getRandomType() { + DataFusionDataType dt; + do { + dt = Randomly.fromOptions(DataFusionDataType.values()); + } while (dt == DataFusionDataType.NULL); + + return dt; + } + + @Override + protected boolean canGenerateColumnOfType(DataFusionDataType type) { + return true; + } + + @Override + protected DataFusionExpression generateExpression(DataFusionDataType type, int depth) { + if (depth >= globalState.getOptions().getMaxExpressionDepth() || Randomly.getBoolean()) { + DataFusionDataType expectedType = type; + if (Randomly.getBooleanWithRatherLowProbability()) { // ~10% + expectedType = DataFusionDataType.getRandomWithoutNull(); + } + return generateLeafNode(expectedType); + } + + // nested aggregate is not allowed, so occasionally apply it + Boolean includeAggr = Randomly.getBooleanWithSmallProbability(); + List possibleBaseExprs = getExprsWithReturnType(Optional.of(type)).stream() + // Conditinally apply filter if `includeAggr` set to false + .filter(expr -> includeAggr || expr.exprType != DataFusionBaseExpr.DataFusionBaseExprCategory.AGGREGATE) + .collect(Collectors.toList()); + + if (possibleBaseExprs.isEmpty()) { + dfAssert(type == DataFusionDataType.NULL, "should able to generate expression with type " + type); + return generateLeafNode(type); + } + + DataFusionBaseExpr randomExpr = Randomly.fromList(possibleBaseExprs); + switch (randomExpr.exprType) { + case UNARY_PREFIX: + DataFusionDataType argType = null; + dfAssert(randomExpr.argTypes.size() == 1 && randomExpr.nArgs == 1, + "Unary expression should only have 1 argument" + randomExpr.argTypes); + if (randomExpr.argTypes.get(0) instanceof ArgumentType.Fixed) { + ArgumentType.Fixed possibleArgTypes = (ArgumentType.Fixed) randomExpr.argTypes.get(0); + argType = Randomly.fromList(possibleArgTypes.fixedType); + } else { + argType = type; + } + + return new DataFusionUnaryPrefixOperation(generateExpression(argType, depth + 1), randomExpr); + case UNARY_POSTFIX: + dfAssert(randomExpr.argTypes.size() == 1 && randomExpr.nArgs == 1, + "Unary expression should only have 1 argument" + randomExpr.argTypes); + if (randomExpr.argTypes.get(0) instanceof ArgumentType.Fixed) { + ArgumentType.Fixed possibleArgTypes = (ArgumentType.Fixed) randomExpr.argTypes.get(0); + argType = Randomly.fromList(possibleArgTypes.fixedType); + } else { + argType = type; + } + + return new DataFusionUnaryPostfixOperation(generateExpression(argType, depth + 1), randomExpr); + case BINARY: + dfAssert(randomExpr.argTypes.size() == 2 && randomExpr.nArgs == 2, + "Binrary expression should only have 2 argument" + randomExpr.argTypes); + List argTypeList = new ArrayList<>(); // types of current expression's input + // arguments + for (ArgumentType argumentType : randomExpr.argTypes) { + if (argumentType instanceof ArgumentType.Fixed) { + ArgumentType.Fixed possibleArgTypes = (ArgumentType.Fixed) randomExpr.argTypes.get(0); + dfAssert(!possibleArgTypes.fixedType.isEmpty(), "possible types can't be an empty list"); + DataFusionDataType determinedType = Randomly.fromList(possibleArgTypes.fixedType); + argTypeList.add(determinedType); + } else if (argumentType instanceof ArgumentType.SameAsFirstArgType) { + dfAssert(!argTypeList.isEmpty(), "First argument can't have argument type `SameAsFirstArgType`"); + DataFusionDataType firstArgType = argTypeList.get(0); + argTypeList.add(firstArgType); + } else { + // Same as expression return type + argTypeList.add(type); + } + } + + return new DataFusionBinaryOperation(generateExpression(argTypeList.get(0), depth + 1), + generateExpression(argTypeList.get(1), depth + 1), randomExpr); + case AGGREGATE: + // Fall through + case FUNC: + return generateFunctionExpression(type, depth, randomExpr); + default: + dfAssert(false, "unreachable"); + } + + dfAssert(false, "unreachable"); + return null; + } + + public DataFusionExpression generateFunctionExpression(DataFusionDataType type, int depth, + DataFusionBaseExpr exprType) { + if (exprType.isVariadic || Randomly.getBooleanWithSmallProbability()) { + // TODO(datafusion) maybe add possible types. e.g. some function have signature + // variadic(INT/DOUBLE), then + // only randomly pick from INT and DOUBLE + int nArgs = Randomly.smallNumber(); // 0, 2, 4, ... smaller one is more likely + return new DataFusionFunction(generateExpressions(nArgs), exprType); + } + + List funcArgTypeList = new ArrayList<>(); // types of current expression's input arguments + int i = 0; + for (ArgumentType argumentType : exprType.argTypes) { + if (argumentType instanceof ArgumentType.Fixed) { + ArgumentType.Fixed possibleArgTypes = (ArgumentType.Fixed) exprType.argTypes.get(i); + dfAssert(!possibleArgTypes.fixedType.isEmpty(), "possible types can't be an empty list"); + DataFusionDataType determinedType = Randomly.fromList(possibleArgTypes.fixedType); + funcArgTypeList.add(determinedType); + } else if (argumentType instanceof ArgumentType.SameAsFirstArgType) { + dfAssert(!funcArgTypeList.isEmpty(), "First argument can't have argument type `SameAsFirstArgType`"); + DataFusionDataType firstArgType = funcArgTypeList.get(0); + funcArgTypeList.add(firstArgType); + } else { + // Same as expression return type + funcArgTypeList.add(type); + } + i++; + } + + List argExpressions = new ArrayList<>(); + + for (DataFusionDataType dataType : funcArgTypeList) { + argExpressions.add(generateExpression(dataType, depth + 1)); + } + + return new DataFusionFunction(argExpressions, exprType); + } + + List filterColumns(DataFusionDataType type) { + if (columns == null) { + return Collections.emptyList(); + } else { + return columns.stream().filter(c -> c.getType() == type).collect(Collectors.toList()); + } + } + + @Override + protected DataFusionExpression generateColumn(DataFusionDataType type) { + // HACK: if no col of such type exist, generate constant value instead + List colsOfType = filterColumns(type); + if (colsOfType.isEmpty()) { + return generateConstant(type); + } + + DataFusionColumn column = Randomly.fromList(colsOfType); + return new DataFusionColumnReference(column); + } + + @Override + public DataFusionExpression generateConstant(DataFusionDataType type) { + return type.getRandomConstant(globalState); + } + + @Override + public DataFusionExpression generatePredicate() { + return generateExpression(DataFusionDataType.BOOLEAN, 0); + } + + @Override + public DataFusionExpression negatePredicate(DataFusionExpression predicate) { + return new DataFusionUnaryPrefixOperation(predicate, createExpr(DataFusionBaseExprType.NOT)); + } + + @Override + public DataFusionExpression isNull(DataFusionExpression expr) { + return new DataFusionUnaryPostfixOperation(expr, createExpr(DataFusionBaseExprType.IS_NULL)); + } + + public static class DataFusionCastOperation extends NewUnaryPostfixOperatorNode { + + public DataFusionCastOperation(DataFusionExpression expr, DataFusionDataType type) { + super(expr, new Operator() { + + @Override + public String getTextRepresentation() { + return "::" + type.toString(); + } + }); + } + + } + + @Override + public DataFusionExpressionGenerator setTablesAndColumns(AbstractTables tables) { + List randomTables = Randomly.nonEmptySubset(tables.getTables()); + int maxSize = Randomly.fromOptions(1, 2, 3, 4); + if (randomTables.size() > maxSize) { + randomTables = randomTables.subList(0, maxSize); + } + this.columns = DataFusionTable.getAllColumns(randomTables); + this.tables = randomTables; + + return this; + } + + @Override + public DataFusionExpression generateBooleanExpression() { + return generateExpression(DataFusionDataType.BOOLEAN); + } + + @Override + public DataFusionSelect generateSelect() { + return new DataFusionSelect(); + } + + @Override + public List getRandomJoinClauses() { + List tableList = tables.stream().map(t -> new DataFusionTableReference(t)) + .collect(Collectors.toList()); + List joins = DataFusionJoin.getJoins(tableList, globalState); + tables = tableList.stream().map(t -> t.getTable()).collect(Collectors.toList()); + return joins; + } + + @Override + public List getTableRefs() { + return tables.stream().map(t -> new DataFusionTableReference(t)).collect(Collectors.toList()); + } + + @Override + public String generateOptimizedQueryString(DataFusionSelect select, DataFusionExpression whereCondition, + boolean shouldUseAggregate) { + if (shouldUseAggregate) { + select.setFetchColumnsString("COUNT(*)"); + } else { + List allColumns = columns.stream().map((c) -> new DataFusionColumnReference(c)) + .collect(Collectors.toList()); + select.setFetchColumns(allColumns); + } + select.setWhereClause(whereCondition); + + return select.asString(); + } + + @Override + public String generateUnoptimizedQueryString(DataFusionSelect select, DataFusionExpression whereCondition) { + String fetchColumn = String.format("COUNT(CASE WHEN %S THEN 1 ELSE NULL END)", + DataFusionToStringVisitor.asString(whereCondition)); + select.setFetchColumnsString(fetchColumn); + select.setWhereClause(null); + + return select.asString(); + } + + @Override + public List generateFetchColumns(boolean shouldCreateDummy) { + List randomColumns = DataFusionTable.getRandomColumns(tables); + return randomColumns.stream().map((c) -> new DataFusionColumnReference(c)).collect(Collectors.toList()); + } +} diff --git a/src/sqlancer/datafusion/gen/DataFusionInsertGenerator.java b/src/sqlancer/datafusion/gen/DataFusionInsertGenerator.java new file mode 100644 index 000000000..1ee00dd50 --- /dev/null +++ b/src/sqlancer/datafusion/gen/DataFusionInsertGenerator.java @@ -0,0 +1,54 @@ +package sqlancer.datafusion.gen; + +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.IgnoreMeException; +import sqlancer.common.gen.AbstractInsertGenerator; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.datafusion.DataFusionProvider.DataFusionGlobalState; +import sqlancer.datafusion.DataFusionSchema.DataFusionColumn; +import sqlancer.datafusion.DataFusionSchema.DataFusionTable; +import sqlancer.datafusion.DataFusionToStringVisitor; + +public class DataFusionInsertGenerator extends AbstractInsertGenerator { + + private final DataFusionGlobalState globalState; + private final ExpectedErrors errors = new ExpectedErrors(); + + public DataFusionInsertGenerator(DataFusionGlobalState globalState) { + this.globalState = globalState; + } + + public static SQLQueryAdapter getQuery(DataFusionGlobalState globalState, DataFusionTable targetTable) { + return new DataFusionInsertGenerator(globalState).generate(targetTable); + } + + private SQLQueryAdapter generate(DataFusionTable targetTable) { + // `sb` is a global `StringBuilder` for current insert query + sb.append("INSERT INTO "); + + if (targetTable.getColumns().isEmpty()) { + throw new IgnoreMeException(); + } + List columns = targetTable.getRandomNonEmptyColumnSubset(); + + sb.append(targetTable.getName()); + sb.append("("); + sb.append(columns.stream().map(c -> c.getName()).collect(Collectors.joining(", "))); + sb.append(")"); + sb.append(" VALUES "); + insertColumns(columns); // will finally call `insertValue()` to generate random value + + return new SQLQueryAdapter(sb.toString(), errors); + } + + @Override + protected void insertValue(DataFusionColumn col) { + String val = DataFusionToStringVisitor + .asString(new DataFusionExpressionGenerator(globalState).generateConstant(col.getType())); + sb.append(val); + } + +} diff --git a/src/sqlancer/datafusion/gen/DataFusionTableGenerator.java b/src/sqlancer/datafusion/gen/DataFusionTableGenerator.java new file mode 100644 index 000000000..adececaa7 --- /dev/null +++ b/src/sqlancer/datafusion/gen/DataFusionTableGenerator.java @@ -0,0 +1,33 @@ +package sqlancer.datafusion.gen; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.datafusion.DataFusionProvider.DataFusionGlobalState; +import sqlancer.datafusion.DataFusionSchema.DataFusionDataType; + +public class DataFusionTableGenerator { + + // Randomly generate a query like 'create table t1 (v1 bigint, v2 boolean)' + public SQLQueryAdapter getQuery(DataFusionGlobalState globalState) { + ExpectedErrors errors = new ExpectedErrors(); + StringBuilder sb = new StringBuilder(); + String tableName = globalState.getSchema().getFreeTableName(); + sb.append("CREATE TABLE "); + sb.append(tableName); + sb.append("("); + + int colCount = Randomly.smallNumber() + 1 + (Randomly.getBoolean() ? 1 : 0); + for (int i = 0; i < colCount; i++) { + sb.append("v").append(i).append(" ").append(DataFusionDataType.getRandomWithoutNull().toString()); + + if (i != colCount - 1) { + sb.append(", "); + } + } + + sb.append(");"); + + return new SQLQueryAdapter(sb.toString(), errors, true); + } +} diff --git a/src/sqlancer/datafusion/server/datafusion_server/Cargo.toml b/src/sqlancer/datafusion/server/datafusion_server/Cargo.toml new file mode 100644 index 000000000..332a88e30 --- /dev/null +++ b/src/sqlancer/datafusion/server/datafusion_server/Cargo.toml @@ -0,0 +1,46 @@ +[package] +name = "datafusion-server" +edition = "2021" +description = "Standalone DataFusion server" +license = "Apache-2.0" + +[dependencies] +ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } +arrow = { version = "52.1.0", features = ["prettyprint"] } +arrow-array = { version = "52.1.0", default-features = false, features = ["chrono-tz"] } +arrow-buffer = { version = "52.1.0", default-features = false } +arrow-flight = { version = "52.1.0", features = ["flight-sql-experimental"] } +arrow-ipc = { version = "52.1.0", default-features = false, features = ["lz4"] } +arrow-ord = { version = "52.1.0", default-features = false } +arrow-schema = { version = "52.1.0", default-features = false } +arrow-string = { version = "52.1.0", default-features = false } +async-trait = "0.1.73" +bytes = "1.4" +chrono = { version = ">=0.4.34, <0.4.40", default-features = false } +dashmap = "5.5.0" +# This version is for SQLancer CI run +datafusion = { version = "40.0.0" } +# Use following line if you want to test against the latest main branch of DataFusion +# datafusion = { git = "https://github.com/apache/datafusion.git", branch = "main" } +env_logger = "0.11" +futures = "0.3" +half = { version = "2.2.1", default-features = false } +hashbrown = { version = "0.14.5", features = ["raw"] } +log = "0.4" +num_cpus = "1.13.0" +object_store = { version = "0.10.1", default-features = false } +parking_lot = "0.12" +parquet = { version = "52.0.0", default-features = false, features = ["arrow", "async", "object_store"] } +rand = "0.8" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1" +tokio = { version = "1.36", features = ["macros", "rt", "sync"] } +tonic = "0.11" +uuid = "1.0" +prost = { version = "0.12", default-features = false } +prost-derive = { version = "0.12", default-features = false } +mimalloc = { version = "0.1", default-features = false } + +[[bin]] +name = "datafusion-server" +path = "src/main.rs" \ No newline at end of file diff --git a/src/sqlancer/datafusion/server/datafusion_server/src/main.rs b/src/sqlancer/datafusion/server/datafusion_server/src/main.rs new file mode 100644 index 000000000..057c34883 --- /dev/null +++ b/src/sqlancer/datafusion/server/datafusion_server/src/main.rs @@ -0,0 +1,463 @@ +use arrow::array::{ArrayRef, StringArray}; +use arrow::ipc::writer::IpcWriteOptions; +use arrow::record_batch::RecordBatch; +use arrow_flight::encode::FlightDataEncoderBuilder; +use arrow_flight::flight_descriptor::DescriptorType; +use arrow_flight::flight_service_server::{FlightService, FlightServiceServer}; +use arrow_flight::sql::server::{FlightSqlService, PeekableFlightDataStream}; +use arrow_flight::sql::{ + ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest, + ActionCreatePreparedStatementResult, Any, CommandGetTables, CommandPreparedStatementQuery, + CommandPreparedStatementUpdate, ProstMessageExt, SqlInfo, +}; +use arrow_flight::{ + Action, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse, + IpcMessage, SchemaAsIpc, Ticket, +}; +use arrow_schema::{DataType, Field, Schema}; +use dashmap::DashMap; +use datafusion::logical_expr::LogicalPlan; +use datafusion::prelude::{DataFrame, ParquetReadOptions, SessionConfig, SessionContext}; +use futures::{Stream, StreamExt, TryStreamExt}; +use log::info; +use mimalloc::MiMalloc; +use prost::Message; +use std::pin::Pin; +use std::sync::Arc; +use tokio::sync::Mutex; +use tonic::metadata::MetadataValue; +use tonic::transport::Server; +use tonic::{Request, Response, Status, Streaming}; +use uuid::Uuid; + +#[global_allocator] +static GLOBAL: MiMalloc = MiMalloc; + +macro_rules! status { + ($desc:expr, $err:expr) => { + Status::internal(format!("{}: {} at {}:{}", $desc, $err, file!(), line!())) + }; +} + +/// Adapted from https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/flight/flight_sql_server.rs +/// Can be used as a remote DataFusion server and connected by `JDBC` from client +/// Supported SQL statements: +/// CREATE +/// INSERT +/// SELECT +/// +/// Only single client is supported +/// For now use `ctx` instead of `contexts` inside `FlightSqlServiceImpl` +/// +/// === Below origianl comment === +/// +/// This example shows how to wrap DataFusion with `FlightSqlService` to support connecting +/// to a standalone DataFusion-based server with a JDBC client, using the open source "JDBC Driver +/// for Arrow Flight SQL". +/// +/// To install the JDBC driver in DBeaver for example, see these instructions: +/// https://docs.dremio.com/software/client-applications/dbeaver/ +/// When configuring the driver, specify property "UseEncryption" = false +/// +/// JDBC connection string: "jdbc:arrow-flight-sql://127.0.0.1:50051/" +/// +/// Based heavily on Ballista's implementation: https://github.com/apache/datafusion-ballista/blob/main/ballista/scheduler/src/flight_sql.rs +/// and the example in arrow-rs: https://github.com/apache/arrow-rs/blob/master/arrow-flight/examples/flight_sql_server.rs +/// +#[tokio::main] +async fn main() -> Result<(), Box> { + env_logger::init(); + let addr = "0.0.0.0:50051".parse()?; + let session_ctx = SessionContext::new_with_config( + SessionConfig::new().with_information_schema(true), // enable catalog + ); + let service = FlightSqlServiceImpl { + contexts: Default::default(), + statements: Default::default(), + results: Default::default(), + ctx: Arc::new(Mutex::new(session_ctx)), + }; + info!("Listening on {addr:?}"); + let svc = FlightServiceServer::new(service); + + Server::builder().add_service(svc).serve(addr).await?; + + Ok(()) +} + +pub struct FlightSqlServiceImpl { + contexts: Arc>>, + statements: Arc>, + results: Arc>>, + ctx: Arc>, +} + +impl FlightSqlServiceImpl { + async fn create_ctx(&self) -> Result { + let uuid = Uuid::new_v4().hyphenated().to_string(); + let session_config = SessionConfig::from_env() + .map_err(|e| Status::internal(format!("Error building plan: {e}")))? + .with_information_schema(true); + let ctx = Arc::new(SessionContext::new_with_config(session_config)); + + self.contexts.insert(uuid.clone(), ctx); + Ok(uuid) + } + + fn get_ctx(&self, req: &Request) -> Result, Status> { + // get the token from the authorization header on Request + let auth = req + .metadata() + .get("authorization") + .ok_or_else(|| Status::internal("No authorization header!"))?; + let str = auth + .to_str() + .map_err(|e| Status::internal(format!("Error parsing header: {e}")))?; + let authorization = str.to_string(); + let bearer = "Bearer "; + if !authorization.starts_with(bearer) { + Err(Status::internal("Invalid auth header!"))?; + } + let auth = authorization[bearer.len()..].to_string(); + + if let Some(context) = self.contexts.get(&auth) { + Ok(context.clone()) + } else { + Err(Status::internal(format!( + "Context handle not found: {auth}" + )))? + } + } + + fn get_plan(&self, handle: &str) -> Result { + if let Some(plan) = self.statements.get(handle) { + Ok(plan.clone()) + } else { + Err(Status::internal(format!("Plan handle not found: {handle}")))? + } + } + + fn get_result(&self, handle: &str) -> Result, Status> { + if let Some(result) = self.results.get(handle) { + Ok(result.clone()) + } else { + Err(Status::internal(format!( + "Request handle not found: {handle}" + )))? + } + } + + async fn tables(&self, ctx: Arc) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ + Field::new("catalog_name", DataType::Utf8, true), + Field::new("db_schema_name", DataType::Utf8, true), + Field::new("table_name", DataType::Utf8, false), + Field::new("table_type", DataType::Utf8, false), + ])); + + let mut catalogs = vec![]; + let mut schemas = vec![]; + let mut names = vec![]; + let mut types = vec![]; + for catalog in ctx.catalog_names() { + let catalog_provider = ctx.catalog(&catalog).unwrap(); + for schema in catalog_provider.schema_names() { + let schema_provider = catalog_provider.schema(&schema).unwrap(); + for table in schema_provider.table_names() { + let table_provider = schema_provider.table(&table).await.unwrap().unwrap(); + catalogs.push(catalog.clone()); + schemas.push(schema.clone()); + names.push(table.clone()); + types.push(table_provider.table_type().to_string()) + } + } + } + + RecordBatch::try_new( + schema, + [catalogs, schemas, names, types] + .into_iter() + .map(|i| Arc::new(StringArray::from(i)) as ArrayRef) + .collect::>(), + ) + .unwrap() + } + + fn remove_plan(&self, handle: &str) -> Result<(), Status> { + self.statements.remove(&handle.to_string()); + Ok(()) + } + + fn remove_result(&self, handle: &str) -> Result<(), Status> { + self.results.remove(&handle.to_string()); + Ok(()) + } +} + +#[tonic::async_trait] +impl FlightSqlService for FlightSqlServiceImpl { + type FlightService = FlightSqlServiceImpl; + + // This function will be triggered if client JDBC property's `user` and `password` field set + async fn do_handshake( + &self, + _request: Request>, + ) -> Result< + Response> + Send>>>, + Status, + > { + info!("do_handshake"); + if let Some(msg) = _request.metadata().get("create") { + // A new round start at SQLancer, clear the ctx + info!("Resetting ctx {:?}", msg); + let new_ctx = + SessionContext::new_with_config(SessionConfig::new().with_information_schema(true)); + + let mut ctx_guard = self.ctx.lock().await; // Use `lock()` for async Mutex + *ctx_guard = new_ctx; + + // Clear leaked state from previous round + self.statements.clear(); + self.results.clear(); + self.contexts.clear(); + } + // no authentication actually takes place here + // see Ballista implementation for example of basic auth + // in this case, we simply accept the connection and create a new SessionContext + // the SessionContext will be re-used within this same connection/session + let token = self.create_ctx().await?; + + let result = HandshakeResponse { + protocol_version: 0, + payload: token.as_bytes().to_vec().into(), + }; + let result = Ok(result); + let output = futures::stream::iter(vec![result]); + let str = format!("Bearer {token}"); + let mut resp: Response> + Send>>> = + Response::new(Box::pin(output)); + let md = MetadataValue::try_from(str) + .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + resp.metadata_mut().insert("authorization", md); + Ok(resp) + } + + async fn do_get_fallback( + &self, + _request: Request, + message: Any, + ) -> Result::DoGetStream>, Status> { + if !message.is::() { + Err(Status::unimplemented(format!( + "do_get: The defined request is invalid: {}", + message.type_url + )))? + } + + let fr: FetchResults = message + .unpack() + .map_err(|e| Status::internal(format!("{e:?}")))? + .ok_or_else(|| Status::internal("Expected FetchResults but got None!"))?; + + let handle = fr.handle; + + info!("getting results for {handle}"); + let result = self.get_result(&handle)?; + // if we get an empty result, create an empty schema + let (schema, batches) = match result.first() { + None => (Arc::new(Schema::empty()), vec![]), + Some(batch) => (batch.schema(), result.clone()), + }; + + let batch_stream = futures::stream::iter(batches).map(Ok); + + let stream = FlightDataEncoderBuilder::new() + .with_schema(schema) + .build(batch_stream) + .map_err(Status::from); + + Ok(Response::new(Box::pin(stream))) + } + + async fn get_flight_info_prepared_statement( + &self, + cmd: CommandPreparedStatementQuery, + _request: Request, + ) -> Result, Status> { + info!("get_flight_info_prepared_statement {:?}", cmd); + let handle = std::str::from_utf8(&cmd.prepared_statement_handle) + .map_err(|e| status!("Unable to parse uuid", e))?; + + //let ctx = self.get_ctx(&request)?; + let plan = self.get_plan(handle)?; + + let ctx_guard = self.ctx.lock().await; + let state = (*ctx_guard).state(); + let df = DataFrame::new(state, plan); + let result = df + .collect() + .await + .map_err(|e| status!("Error executing query", e))?; + + // if we get an empty result, create an empty schema + let schema = match result.first() { + None => Schema::empty(), + Some(batch) => (*batch.schema()).clone(), + }; + + self.results.insert(handle.to_string(), result); + + // if we had multiple endpoints to connect to, we could use this Location + // but in the case of standalone DataFusion, we don't + // let loc = Location { + // uri: "grpc+tcp://127.0.0.1:50051".to_string(), + // }; + let fetch = FetchResults { + handle: handle.to_string(), + }; + let buf = fetch.as_any().encode_to_vec().into(); + let ticket = Ticket { ticket: buf }; + + let info = FlightInfo::new() + // Encode the Arrow schema + .try_with_schema(&schema) + .expect("encoding failed") + .with_endpoint(FlightEndpoint::new().with_ticket(ticket)) + .with_descriptor(FlightDescriptor { + r#type: DescriptorType::Cmd.into(), + cmd: Default::default(), + path: vec![], + }); + let resp = Response::new(info); + Ok(resp) + } + + async fn get_flight_info_tables( + &self, + _query: CommandGetTables, + request: Request, + ) -> Result, Status> { + info!("get_flight_info_tables"); + let ctx = self.get_ctx(&request)?; + let data = self.tables(ctx).await; + let schema = data.schema(); + + let uuid = Uuid::new_v4().hyphenated().to_string(); + self.results.insert(uuid.clone(), vec![data]); + + let fetch = FetchResults { handle: uuid }; + let buf = fetch.as_any().encode_to_vec().into(); + let ticket = Ticket { ticket: buf }; + + let info = FlightInfo::new() + // Encode the Arrow schema + .try_with_schema(&schema) + .expect("encoding failed") + .with_endpoint(FlightEndpoint::new().with_ticket(ticket)) + .with_descriptor(FlightDescriptor { + r#type: DescriptorType::Cmd.into(), + cmd: Default::default(), + path: vec![], + }); + let resp = Response::new(info); + Ok(resp) + } + + async fn do_put_prepared_statement_update( + &self, + handle: CommandPreparedStatementUpdate, + _request: Request, + ) -> Result { + info!("do_put_prepared_statement_update"); + // statements like "CREATE TABLE.." or "SET datafusion.nnn.." call this function + // and we are required to return some row count here + let handle = std::str::from_utf8(&handle.prepared_statement_handle) + .map_err(|e| status!("Unable to parse uuid", e))?; + + //let ctx = self.get_ctx(&request)?; + let plan = self.get_plan(handle)?; + //println!("do_put_prepared_statement_update plan is {:?}", plan); + + let ctx_guard = self.ctx.lock().await; + let state = (*ctx_guard).state(); + let df = DataFrame::new(state, plan); + df.collect() + .await + .map_err(|e| status!("Error executing query", e))?; + + Ok(1) + } + + async fn do_action_create_prepared_statement( + &self, + query: ActionCreatePreparedStatementRequest, + _request: Request, + ) -> Result { + let user_query = query.query.as_str(); + info!("do_action_create_prepared_statement: {user_query}"); + + //let ctx = self.get_ctx(&request)?; + + let ctx_guard = self.ctx.lock().await; + let plan = (*ctx_guard) + .sql(user_query) + .await + .and_then(|df| df.into_optimized_plan()) + .map_err(|e| Status::internal(format!("Error building plan: {e}")))?; + + // store a copy of the plan, it will be used for execution + let plan_uuid = Uuid::new_v4().hyphenated().to_string(); + self.statements.insert(plan_uuid.clone(), plan.clone()); + + let plan_schema = plan.schema(); + + let arrow_schema = (&**plan_schema).into(); + let message = SchemaAsIpc::new(&arrow_schema, &IpcWriteOptions::default()) + .try_into() + .map_err(|e| status!("Unable to serialize schema", e))?; + let IpcMessage(schema_bytes) = message; + + let res = ActionCreatePreparedStatementResult { + prepared_statement_handle: plan_uuid.into(), + dataset_schema: schema_bytes, + parameter_schema: Default::default(), + }; + Ok(res) + } + + async fn do_action_close_prepared_statement( + &self, + handle: ActionClosePreparedStatementRequest, + _request: Request, + ) -> Result<(), Status> { + info!("do_action_close_prepared_statement"); + let handle = std::str::from_utf8(&handle.prepared_statement_handle); + if let Ok(handle) = handle { + info!("do_action_close_prepared_statement: removing plan and results for {handle}"); + let _ = self.remove_plan(handle); + let _ = self.remove_result(handle); + } + Ok(()) + } + + async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {} +} + +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FetchResults { + #[prost(string, tag = "1")] + pub handle: ::prost::alloc::string::String, +} + +impl ProstMessageExt for FetchResults { + fn type_url() -> &'static str { + "type.googleapis.com/datafusion.example.com.sql.FetchResults" + } + + fn as_any(&self) -> Any { + Any { + type_url: FetchResults::type_url().to_string(), + value: ::prost::Message::encode_to_vec(self).into(), + } + } +} diff --git a/src/sqlancer/doris/DorisBugs.java b/src/sqlancer/doris/DorisBugs.java new file mode 100644 index 000000000..956be1683 --- /dev/null +++ b/src/sqlancer/doris/DorisBugs.java @@ -0,0 +1,45 @@ +package sqlancer.doris; + +public final class DorisBugs { + // https://github.com/apache/doris/issues/19370 + // Internal Error occur in GroupBy&Having sql + // fixed by https://github.com/apache/doris/pull/19559 + public static boolean bug19370; + + // https://github.com/apache/doris/issues/19374 + // Different result of having not ($value in column) and having ($value not in column) + // fixed by https://github.com/apache/doris/pull/19471 + public static boolean bug19374; + + // https://github.com/apache/doris/issues/19611 + // ERROR occur in nested subqueries with same column name and union + public static boolean bug19611 = true; + + // https://github.com/apache/doris/issues/36070 + // Expression evaluate to NULL but is treated as FALSE in where clause + public static boolean bug36070 = true; + + // https://github.com/apache/doris/issues/36072 + // SELECT DISTINCT does not work with aggregate key column + public static boolean bug36072 = true; + + // https://github.com/apache/doris/issues/36342 + // Wrong result with INNER JOIN and CURRENT_TIMESTAMP + public static boolean bug36342 = true; + + // https://github.com/apache/doris/issues/36343 + // Wrong result with SELECT DISTINCT and UNIQUE model + public static boolean bug36343 = true; + + // https://github.com/apache/doris/issues/36346 + // Wrong result with LEFT JOIN SELECT DISTINCT and IN operation + public static boolean bug36346 = true; + + // https://github.com/apache/doris/issues/36351 + // Wrong result with TINYINT column with value -1049190528 + public static boolean bug36351 = true; + + private DorisBugs() { + + } +} diff --git a/src/sqlancer/doris/DorisErrors.java b/src/sqlancer/doris/DorisErrors.java new file mode 100644 index 000000000..28e93b80a --- /dev/null +++ b/src/sqlancer/doris/DorisErrors.java @@ -0,0 +1,85 @@ +package sqlancer.doris; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.common.query.ExpectedErrors; + +public final class DorisErrors { + + private DorisErrors() { + } + + public static List getExpressionErrors() { + ArrayList errors = new ArrayList<>(); + + // SQL syntax error + errors.add("Syntax error"); + errors.add("Please check your sql, we meet an error when parsing"); + errors.add("but returns type"); + errors.add("is not a number"); + + // Not in line with Doris' logic + errors.add("Unexpected exception: null"); + errors.add("Cross join can't be used with ON clause"); + errors.add("BetweenPredicate needs to be rewritten into a CompoundPredicate"); + errors.add("can't be assigned to some PlanNode"); + errors.add("can not cast from origin type"); + errors.add("not produced by aggregation output"); + errors.add("cannot combine"); // cannot combine SELECT DISTINCT with aggregate functions or GROUP BY + errors.add("Invalid type"); + errors.add("cannot be cast to"); + + // functions + errors.add("No matching function with signature"); + errors.add("Invalid number format"); + errors.add("group_concat requires"); + errors.add("function's argument should be"); + errors.add("requires a numeric parameter"); + errors.add("out of bounds"); + errors.add("function do not support"); + errors.add("parameter must be"); + errors.add("Not supported input arguments types"); + errors.add("No matching function with signature"); + errors.add("function"); + errors.add("Invalid"); + errors.add("Incorrect"); + + // regex + + // To avoid bugs + if (DorisBugs.bug19370) { + errors.add("failed to initialize storage"); + } + if (DorisBugs.bug19374) { + errors.add("the size of the result sets mismatch"); + } + if (DorisBugs.bug19611) { + errors.add("Duplicated inline view column alias"); + } + errors.add("Arithmetic overflow"); + + return errors; + } + + public static void addExpressionErrors(ExpectedErrors errors) { + errors.addAll(getExpressionErrors()); + } + + public static List getInsertErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("Insert has filtered data in strict mode"); + errors.add("Only value columns of unique table could be updated"); + errors.add("Only unique olap table could be updated"); + errors.add("Number out of range"); + errors.add("Arithmetic overflow"); + + return errors; + } + + public static void addInsertErrors(ExpectedErrors errors) { + errors.addAll(getInsertErrors()); + } + +} diff --git a/src/sqlancer/doris/DorisOptions.java b/src/sqlancer/doris/DorisOptions.java new file mode 100644 index 000000000..7e5a2c7e3 --- /dev/null +++ b/src/sqlancer/doris/DorisOptions.java @@ -0,0 +1,106 @@ +package sqlancer.doris; + +import java.util.Arrays; +import java.util.List; + +import com.beust.jcommander.Parameter; +import com.beust.jcommander.Parameters; + +import sqlancer.DBMSSpecificOptions; + +@Parameters(commandDescription = "Apache Doris (default port: " + DorisOptions.DEFAULT_PORT + ", default host: " + + DorisOptions.DEFAULT_HOST + ")") +public class DorisOptions implements DBMSSpecificOptions { + public static final String DEFAULT_HOST = "localhost"; + public static final int DEFAULT_PORT = 9030; + + @Parameter(names = { "--max-num-tables" }, description = "The maximum number of tables/views that can be created") + public int maxNumTables = 10; + + @Parameter(names = { "--max-num-indexes" }, description = "The maximum number of indexes that can be created") + public int maxNumIndexes = 20; + + @Parameter(names = "--test-default-values", description = "Allow generating DEFAULT values in tables", arity = 1) + public boolean testDefaultValues = true; + + @Parameter(names = "--test-not-null", description = "Allow generating NOT NULL constraints in tables", arity = 1) + public boolean testNotNullConstraints = true; + + @Parameter(names = "--test-functions", description = "Allow generating functions in expressions", arity = 1) + public boolean testFunctions; + + @Parameter(names = "--test-casts", description = "Allow generating casts in expressions", arity = 1) + public boolean testCasts = true; + + @Parameter(names = "--test-between", description = "Allow generating the BETWEEN operator in expressions", arity = 1) + public boolean testBetween = true; + + @Parameter(names = "--test-in", description = "Allow generating the IN operator in expressions", arity = 1) + public boolean testIn = true; + + @Parameter(names = "--test-case", description = "Allow generating the CASE operator in expressions", arity = 1) + public boolean testCase = true; + + @Parameter(names = "--test-binary-logicals", description = "Allow generating AND and OR in expressions", arity = 1) + public boolean testBinaryLogicals = true; + + @Parameter(names = "--test-int-constants", description = "Allow generating INTEGER constants", arity = 1) + public boolean testIntConstants = true; + + @Parameter(names = "--test-float-constants", description = "Allow generating floating-point constants", arity = 1) + public boolean testFloatConstants = true; + + @Parameter(names = "--test-decimal-constants", description = "Allow generating DECIMAL constants", arity = 1) + public boolean testDecimalConstants = true; + + @Parameter(names = "--test-date-constants", description = "Allow generating DATE constants", arity = 1) + public boolean testDateConstants = true; + + @Parameter(names = "--test-datetime-constants", description = "Allow generating DATETIME constants", arity = 1) + public boolean testDateTimeConstants = true; + + @Parameter(names = "--test-varchar-constants", description = "Allow generating VARCHAR constants", arity = 1) + public boolean testStringConstants = true; + + @Parameter(names = "--test-boolean-constants", description = "Allow generating boolean constants", arity = 1) + public boolean testBooleanConstants = true; + + @Parameter(names = "--test-binary-comparisons", description = "Allow generating binary comparison operators (e.g., >= or LIKE)", arity = 1) + public boolean testBinaryComparisons = true; + + @Parameter(names = "--max-num-deletes", description = "The maximum number of DELETE statements that are issued for a database", arity = 1) + public int maxNumDeletes = 1; + + @Parameter(names = "--max-num-updates", description = "The maximum number of UPDATE statements that are issued for a database", arity = 1) + public int maxNumUpdates; + + @Parameter(names = "--max-num-table-alters", description = "The maximum number of ALTER TABLE statements that are issued for a database", arity = 1) + public int maxNumTableAlters; + + @Parameter(names = "--test-engine-type", description = "The engine type in Doris, only consider OLAP now", arity = 1) + public String testEngineType = "OLAP"; // skip now + + @Parameter(names = "--test-indexes", description = "Allow explicit indexes, Doris only supports creating indexes on single-column BITMAP", arity = 1) + public boolean testIndexes = true; // skip now + + @Parameter(names = "--test-column-aggr", description = "Allow test column aggregation (sum, min, max, replace, replace_if_not_null, hll_union, bitmap_untion)", arity = 1) + public boolean testColumnAggr = true; + + @Parameter(names = "--test-datemodel", description = "Allow generating Doris’s data model in tables. (Aggregate、Uniqe、Duplicate)", arity = 1) + public boolean testDataModel = true; + + @Parameter(names = "--test-distribution", description = "Allow generating data distribution in tables.", arity = 1) + public boolean testDistribution = true; // must have it, skip now + + @Parameter(names = "--test-rollup", description = "Allow generating rollups in tables.", arity = 1) + public boolean testRollup = true; // skip now + + @Parameter(names = "--oracle") + public List oracles = Arrays.asList(DorisOracleFactory.NOREC); + + @Override + public List getTestOracleFactory() { + return oracles; + } + +} diff --git a/src/sqlancer/doris/DorisOracleFactory.java b/src/sqlancer/doris/DorisOracleFactory.java new file mode 100644 index 000000000..8f28f0f21 --- /dev/null +++ b/src/sqlancer/doris/DorisOracleFactory.java @@ -0,0 +1,108 @@ +package sqlancer.doris; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import sqlancer.OracleFactory; +import sqlancer.common.oracle.CompositeTestOracle; +import sqlancer.common.oracle.NoRECOracle; +import sqlancer.common.oracle.TLPWhereOracle; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.doris.gen.DorisNewExpressionGenerator; +import sqlancer.doris.oracle.DorisPivotedQuerySynthesisOracle; +import sqlancer.doris.oracle.tlp.DorisQueryPartitioningAggregateTester; +import sqlancer.doris.oracle.tlp.DorisQueryPartitioningDistinctTester; +import sqlancer.doris.oracle.tlp.DorisQueryPartitioningGroupByTester; +import sqlancer.doris.oracle.tlp.DorisQueryPartitioningHavingTester; + +public enum DorisOracleFactory implements OracleFactory { + NOREC { + @Override + public TestOracle create(DorisProvider.DorisGlobalState globalState) + throws SQLException { + DorisNewExpressionGenerator gen = new DorisNewExpressionGenerator(globalState); + ExpectedErrors errors = ExpectedErrors.newErrors().with(DorisErrors.getExpressionErrors()) + .with("canceling statement due to statement timeout").build(); + return new NoRECOracle<>(globalState, gen, errors); + } + + }, + HAVING { + @Override + public TestOracle create(DorisProvider.DorisGlobalState globalState) + throws SQLException { + return new DorisQueryPartitioningHavingTester(globalState); + } + }, + WHERE { + @Override + public TestOracle create(DorisProvider.DorisGlobalState globalState) + throws SQLException { + DorisNewExpressionGenerator gen = new DorisNewExpressionGenerator(globalState); + ExpectedErrors expectedErrors = ExpectedErrors.newErrors().with(DorisErrors.getExpressionErrors()) + .with(DorisErrors.getExpressionErrors()).build(); + + return new TLPWhereOracle<>(globalState, gen, expectedErrors); + } + }, + GROUP_BY { + @Override + public TestOracle create(DorisProvider.DorisGlobalState globalState) + throws SQLException { + return new DorisQueryPartitioningGroupByTester(globalState); + } + }, + AGGREGATE { + @Override + public TestOracle create(DorisProvider.DorisGlobalState globalState) + throws SQLException { + return new DorisQueryPartitioningAggregateTester(globalState); + } + + }, + DISTINCT { + @Override + public TestOracle create(DorisProvider.DorisGlobalState globalState) + throws SQLException { + return new DorisQueryPartitioningDistinctTester(globalState); + } + }, + QUERY_PARTITIONING { + @Override + public TestOracle create(DorisProvider.DorisGlobalState globalState) + throws Exception { + List> oracles = new ArrayList<>(); + oracles.add(WHERE.create(globalState)); + oracles.add(HAVING.create(globalState)); + oracles.add(AGGREGATE.create(globalState)); + oracles.add(DISTINCT.create(globalState)); + oracles.add(GROUP_BY.create(globalState)); + return new CompositeTestOracle(oracles, globalState); + } + }, + PQS { + @Override + public TestOracle create(DorisProvider.DorisGlobalState globalState) + throws Exception { + return new DorisPivotedQuerySynthesisOracle(globalState); + } + }, + ALL { + @Override + public TestOracle create(DorisProvider.DorisGlobalState globalState) + throws Exception { + List> oracles = new ArrayList<>(); + oracles.add(NOREC.create(globalState)); + oracles.add(WHERE.create(globalState)); + oracles.add(HAVING.create(globalState)); + oracles.add(AGGREGATE.create(globalState)); + oracles.add(DISTINCT.create(globalState)); + oracles.add(GROUP_BY.create(globalState)); + oracles.add(new DorisPivotedQuerySynthesisOracle(globalState)); + return new CompositeTestOracle(oracles, globalState); + } + } + +} diff --git a/src/sqlancer/doris/DorisProvider.java b/src/sqlancer/doris/DorisProvider.java new file mode 100644 index 000000000..3f64231f1 --- /dev/null +++ b/src/sqlancer/doris/DorisProvider.java @@ -0,0 +1,155 @@ +package sqlancer.doris; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.sql.Statement; + +import com.google.auto.service.AutoService; + +import sqlancer.AbstractAction; +import sqlancer.DatabaseProvider; +import sqlancer.IgnoreMeException; +import sqlancer.MainOptions; +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.SQLGlobalState; +import sqlancer.SQLProviderAdapter; +import sqlancer.StatementExecutor; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.query.SQLQueryProvider; +import sqlancer.doris.DorisProvider.DorisGlobalState; +import sqlancer.doris.gen.DorisAlterTableGenerator; +import sqlancer.doris.gen.DorisDeleteGenerator; +import sqlancer.doris.gen.DorisDropTableGenerator; +import sqlancer.doris.gen.DorisDropViewGenerator; +import sqlancer.doris.gen.DorisIndexGenerator; +import sqlancer.doris.gen.DorisInsertGenerator; +import sqlancer.doris.gen.DorisTableGenerator; +import sqlancer.doris.gen.DorisUpdateGenerator; +import sqlancer.doris.gen.DorisViewGenerator; + +@AutoService(DatabaseProvider.class) +public class DorisProvider extends SQLProviderAdapter { + + public DorisProvider() { + super(DorisGlobalState.class, DorisOptions.class); + } + + public enum Action implements AbstractAction { + CREATE_TABLE(DorisTableGenerator::createRandomTableStatement), CREATE_VIEW(DorisViewGenerator::getQuery), + CREATE_INDEX(DorisIndexGenerator::getQuery), INSERT(DorisInsertGenerator::getQuery), + DELETE(DorisDeleteGenerator::generate), UPDATE(DorisUpdateGenerator::getQuery), + ALTER_TABLE(DorisAlterTableGenerator::getQuery), + TRUNCATE((g) -> new SQLQueryAdapter( + "TRUNCATE TABLE " + g.getSchema().getRandomTable(t -> !t.isView()).getName())), + DROP_TABLE(DorisDropTableGenerator::dropTable), DROP_VIEW(DorisDropViewGenerator::dropView); + + private final SQLQueryProvider sqlQueryProvider; + + Action(SQLQueryProvider sqlQueryProvider) { + this.sqlQueryProvider = sqlQueryProvider; + } + + @Override + public SQLQueryAdapter getQuery(DorisGlobalState state) throws Exception { + return sqlQueryProvider.getQuery(state); + } + } + + private static int mapActions(DorisGlobalState globalState, Action a) { + Randomly r = globalState.getRandomly(); + switch (a) { + case INSERT: + return r.getInteger(0, globalState.getOptions().getMaxNumberInserts()); + case DELETE: + return r.getInteger(0, globalState.getDbmsSpecificOptions().maxNumDeletes); + case UPDATE: + return r.getInteger(0, globalState.getDbmsSpecificOptions().maxNumUpdates); + case ALTER_TABLE: + return r.getInteger(0, globalState.getDbmsSpecificOptions().maxNumTableAlters); + case TRUNCATE: + return r.getInteger(0, 2); + case CREATE_TABLE: + case CREATE_INDEX: + case CREATE_VIEW: + case DROP_TABLE: + case DROP_VIEW: + return 0; + default: + throw new AssertionError(a); + } + } + + public static class DorisGlobalState extends SQLGlobalState { + + @Override + protected DorisSchema readSchema() throws SQLException { + return DorisSchema.fromConnection(getConnection(), getDatabaseName()); + } + + } + + @Override + public void generateDatabase(DorisGlobalState globalState) throws Exception { + for (int i = 0; i < Randomly.fromOptions(1, 2); i++) { + boolean success = false; + do { + SQLQueryAdapter qt = new DorisTableGenerator().getQuery(globalState); + if (qt != null) { + success = globalState.executeStatement(qt); + } + } while (!success); + } + if (globalState.getSchema().getDatabaseTables().isEmpty()) { + throw new IgnoreMeException(); + } + StatementExecutor se = new StatementExecutor<>(globalState, Action.values(), + DorisProvider::mapActions, (q) -> { + if (globalState.getSchema().getDatabaseTables().isEmpty()) { + throw new IgnoreMeException(); + } + }); + se.executeStatements(); + } + + @Override + public SQLConnection createDatabase(DorisGlobalState globalState) throws SQLException { + String username = globalState.getOptions().getUserName(); + String password = globalState.getOptions().getPassword(); + if (password.equals("\"\"")) { + password = ""; + } + String host = globalState.getOptions().getHost(); + int port = globalState.getOptions().getPort(); + if (host == null) { + host = DorisOptions.DEFAULT_HOST; + } + if (port == MainOptions.NO_SET_PORT) { + port = DorisOptions.DEFAULT_PORT; + } + String databaseName = globalState.getDatabaseName(); + globalState.getState().logStatement("DROP DATABASE IF EXISTS " + databaseName); + globalState.getState().logStatement("CREATE DATABASE " + databaseName); + globalState.getState().logStatement("USE " + databaseName); + String url = String.format("jdbc:mysql://%s:%d?serverTimezone=UTC&useSSL=false&allowPublicKeyRetrieval=true", + host, port); + Connection con = DriverManager.getConnection(url, username, password); + try (Statement s = con.createStatement()) { + s.execute("DROP DATABASE IF EXISTS " + databaseName); + } + try (Statement s = con.createStatement()) { + s.execute("CREATE DATABASE " + databaseName); + } + try (Statement s = con.createStatement()) { + s.execute("USE " + databaseName); + } + return new SQLConnection(con); + } + + @Override + public String getDBMSName() { + return "doris"; + } + +} diff --git a/src/sqlancer/doris/DorisSchema.java b/src/sqlancer/doris/DorisSchema.java new file mode 100644 index 000000000..70a61ee62 --- /dev/null +++ b/src/sqlancer/doris/DorisSchema.java @@ -0,0 +1,619 @@ +package sqlancer.doris; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.common.DBMSCommon; +import sqlancer.common.schema.AbstractRelationalTable; +import sqlancer.common.schema.AbstractRowValue; +import sqlancer.common.schema.AbstractSchema; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.common.schema.AbstractTables; +import sqlancer.common.schema.TableIndex; +import sqlancer.doris.DorisProvider.DorisGlobalState; +import sqlancer.doris.DorisSchema.DorisTable; +import sqlancer.doris.ast.DorisConstant; + +public class DorisSchema extends AbstractSchema { + + public enum DorisTableDataModel { + UNIQUE, AGGREGATE, DUPLICATE; + + public static DorisTableDataModel getRandom() { + List validOptions = new ArrayList<>(Arrays.asList(values())); + if (DorisBugs.bug36072) { + validOptions.remove(AGGREGATE); + } + if (DorisBugs.bug36343) { + validOptions.remove(UNIQUE); + } + return Randomly.fromList(validOptions); + } + } + + public enum DorisColumnAggrType { + SUM, MIN, MAX, REPLACE, REPLCAE_IF_NOT_NULL, BITMAP_UNION, HLL_UNION, NULL; + + public static DorisColumnAggrType getRandom(DorisCompositeDataType columnDataType) { + // if (columnDataType.getPrimitiveDataType() == DorisSchema.DorisDataType.BITMAP) { + // return DorisColumnAggrType.BITMAP_UNION; + // } + // if (columnDataType.getPrimitiveDataType() == DorisSchema.DorisDataType.HLL) { + // return DorisColumnAggrType.HLL_UNION; + // } + + return Randomly.fromOptions(SUM, MIN, MAX, REPLACE, REPLCAE_IF_NOT_NULL); + } + } + + public enum DorisDataType { + INT, FLOAT, DECIMAL, DATE, DATETIME, VARCHAR, BOOLEAN, NULL; + // HLL, BITMAP, ARRAY; + + private int decimalScale; + private int decimalPrecision; + private int varcharLength; + + public static DorisDataType getRandomWithoutNull() { + DorisDataType dt; + do { + dt = Randomly.fromOptions(values()); + } while (dt == DorisDataType.NULL); + return dt; + } + + public int getDecimalScale() { + return decimalScale; + } + + public void setDecimalScale(int decimalScale) { + this.decimalScale = decimalScale; + } + + public int getDecimalPrecision() { + return decimalPrecision; + } + + public void setDecimalPrecision(int decimalPrecision) { + this.decimalPrecision = decimalPrecision; + } + + public int getVarcharLength() { + return varcharLength; + } + + public void setVarcharLength(int varcharLength) { + this.varcharLength = varcharLength; + } + } + + public static class DorisCompositeDataType { + + private final DorisDataType dataType; + + private final int size; + + public DorisCompositeDataType(DorisDataType dataType, int size) { + this.dataType = dataType; + this.size = size; + } + + public DorisDataType getPrimitiveDataType() { + return dataType; + } + + public int getSize() { + if (size == -1) { + throw new AssertionError(this); + } + return size; + } + + public static DorisCompositeDataType getRandomWithoutNull() { + DorisDataType type = DorisDataType.getRandomWithoutNull(); + int size = -1; + switch (type) { + case INT: + size = Randomly.fromOptions(1, 2, 4, 8, 16); + break; + case FLOAT: + size = Randomly.fromOptions(4, 12); + break; + case DECIMAL: + size = Randomly.fromOptions(1, 3); // DECIMAL or DECIMALV3 + break; + case DATE: + case DATETIME: + case VARCHAR: + case BOOLEAN: + // case HLL: + // case BITMAP: + // case ARRAY: + size = 0; + break; + default: + throw new AssertionError(type); + } + + return new DorisCompositeDataType(type, size); + } + + public void initColumnArgs() { + Randomly r = new Randomly(); + int scale; + int precision; + int varcharLength; + switch (getPrimitiveDataType()) { + case DECIMAL: + if (getPrimitiveDataType().getDecimalPrecision() != 0) { + break; + } + if (size == 1) { + scale = r.getInteger(0, 9); + precision = r.getInteger(scale + 1, scale + 18); + getPrimitiveDataType().setDecimalPrecision(precision); + getPrimitiveDataType().setDecimalScale(scale); + } else if (size == 3) { + precision = r.getInteger(1, 38); + scale = r.getInteger(0, precision); + getPrimitiveDataType().setDecimalPrecision(precision); + getPrimitiveDataType().setDecimalScale(scale); + } else { + throw new AssertionError(size); + } + break; + case VARCHAR: + if (getPrimitiveDataType().getVarcharLength() != 0) { + break; + } + varcharLength = r.getInteger(1, 255); + getPrimitiveDataType().setVarcharLength(varcharLength); + break; + default: + // pass + } + + } + + @Override + public String toString() { + switch (getPrimitiveDataType()) { + case INT: + switch (size) { + case 16: + return "LARGEINT"; + case 8: + return "BIGINT"; + case 4: + return "INT"; + case 2: + return "SMALLINT"; + case 1: + return "TINYINT"; + default: + throw new AssertionError(size); + } + case FLOAT: + switch (size) { + case 12: + return "DOUBLE"; + case 4: + return "FLOAT"; + default: + throw new AssertionError(size); + } + case DECIMAL: + switch (size) { + case 1: + return "DECIMAL(" + getPrimitiveDataType().getDecimalPrecision() + "," + + getPrimitiveDataType().getDecimalScale() + ")"; + case 3: + return "DECIMALV3(" + getPrimitiveDataType().getDecimalPrecision() + "," + + getPrimitiveDataType().getDecimalScale() + ")"; + default: + throw new AssertionError(size); + } + case DATE: + return "DATEV2"; + case DATETIME: + return Randomly.fromOptions("DATETIME", "DATETIMEV2"); + case VARCHAR: + return Randomly.fromOptions("VARCHAR", "CHAR") + "(" + getPrimitiveDataType().getVarcharLength() + ")"; + case BOOLEAN: + return "BOOLEAN"; + // case HLL: + // return "HLL"; + // case BITMAP: + // return "BITMAP"; + // case ARRAY: + // return "ARRAY"; + case NULL: + return Randomly.fromOptions("NULL"); + default: + throw new AssertionError(getPrimitiveDataType()); + } + } + + public boolean canBeKey() { + switch (dataType) { + // case HLL: + // case BITMAP: + // case ARRAY: + case FLOAT: + return false; + default: + return true; + } + } + + } + + public static class DorisColumn extends AbstractTableColumn { + + private final boolean isKey; + private final boolean isNullable; + private final DorisColumnAggrType aggrType; + private final boolean hasDefaultValue; + private final String defaultValue; + + public DorisColumn(String name, DorisCompositeDataType type, boolean isKey, boolean isNullable, + DorisColumnAggrType aggrType, boolean hasDefaultValue, String defaultValue) { + super(name, null, type); + this.isKey = isKey; + this.isNullable = isNullable; + this.aggrType = aggrType; + this.hasDefaultValue = hasDefaultValue; + this.defaultValue = defaultValue; + } + + public DorisColumn(String name, DorisCompositeDataType type, boolean isKey, boolean isNullable) { + super(name, null, type); + this.isKey = isKey; + this.isNullable = isNullable; + this.aggrType = DorisColumnAggrType.NULL; + this.hasDefaultValue = false; + this.defaultValue = ""; + } + + public boolean isKey() { + return isKey; + } + + public boolean isNullable() { + return isNullable; + } + + public boolean hasDefaultValue() { + return hasDefaultValue; + } + + @Override + public String toString() { + String ret = this.getName() + " " + this.getType(); + if (aggrType != DorisColumnAggrType.NULL) { + ret += " " + aggrType.name(); + } + if (!isNullable) { + ret += " NOT NULL"; + } + if (hasDefaultValue) { + ret += " DEFAULT " + defaultValue; + } + return ret; + } + + @Override + public int compareTo(AbstractTableColumn o) { + // To sort columns + DorisColumn other = (DorisColumn) o; + if (isKey != other.isKey) { + return isKey ? 1 : -1; + } + return getName().compareTo(other.getName()); + } + } + + public static class DorisTables extends AbstractTables { + + public DorisTables(List tables) { + super(tables); + } + + public DorisRowValue getRandomRowValue(SQLConnection con) throws SQLException { + String rowValueQuery = String.format("SELECT %s FROM %s ORDER BY 1 LIMIT 1", columnNamesAsString( + c -> c.getTable().getName() + "." + c.getName() + " AS " + c.getTable().getName() + c.getName()), + tableNamesAsString()); + Map values = new HashMap<>(); + try (Statement s = con.createStatement()) { + ResultSet rs = s.executeQuery(rowValueQuery); + if (!rs.next()) { + throw new IgnoreMeException(); + // throw new AssertionError("could not find random row " + rowValueQuery + "\n"); + } + for (int i = 0; i < getColumns().size(); i++) { + DorisColumn column = getColumns().get(i); + int columnIndex = rs.findColumn(column.getTable().getName() + column.getName()); + assert columnIndex == i + 1; + DorisConstant constant; + if (rs.getString(columnIndex) == null) { + constant = DorisConstant.createNullConstant(); + } else { + switch (column.getType().getPrimitiveDataType()) { + case INT: + constant = DorisConstant.createIntConstant(rs.getLong(columnIndex)); + break; + case FLOAT: + case DECIMAL: + constant = DorisConstant.createFloatConstant(rs.getDouble(columnIndex)); + break; + case DATE: + constant = DorisConstant.createDateConstant(rs.getString(columnIndex)); + break; + case DATETIME: + constant = DorisConstant.createDatetimeConstant(rs.getString(columnIndex)); + break; + case VARCHAR: + constant = DorisConstant.createStringConstant(rs.getString(columnIndex)); + break; + case BOOLEAN: + constant = DorisConstant.createBooleanConstant(rs.getBoolean(columnIndex)); + break; + case NULL: + constant = DorisConstant.createNullConstant(); + break; + default: + throw new IgnoreMeException(); + } + } + values.put(column, constant); + } + assert !rs.next(); + return new DorisRowValue(this, values); + } catch (SQLException e) { + throw new IgnoreMeException(); + } + } + + } + + public static class DorisRowValue extends AbstractRowValue { + + DorisRowValue(DorisTables tables, Map values) { + super(tables, values); + } + + } + + public DorisSchema(List databaseTables) { + super(databaseTables); + } + + public DorisTables getRandomTableNonEmptyTables() { + return new DorisTables(Randomly.nonEmptySubset(getDatabaseTables())); + } + + public DorisTables getRandomTableNonEmptyAndViewTables() { + List tables = getDatabaseTables().stream().filter(t -> !t.isView()).collect(Collectors.toList()); + tables = Randomly.nonEmptySubset(tables); + return new DorisTables(tables); + } + + public int getIndexCount() { + int count = 0; + for (DorisTable table : getDatabaseTables()) { + count += table.getIndexes().size(); + } + return count; + } + + private static DorisCompositeDataType getColumnType(String typeString) { + DorisDataType primitiveType; + int size = -1; + + if (typeString.startsWith("DECIMALV3")) { + primitiveType = DorisDataType.DECIMAL; + String precisionAndScale = typeString.substring(typeString.indexOf('(') + 1, typeString.indexOf(')')); + String[] split = precisionAndScale.split(","); + assert split.length == 2; + primitiveType.setDecimalPrecision(Integer.parseInt(split[0].trim())); + primitiveType.setDecimalScale(Integer.parseInt(split[1].trim())); + size = 3; + } else if (typeString.startsWith("DECIMAL")) { + primitiveType = DorisDataType.DECIMAL; + String precisionAndScale = typeString.substring(typeString.indexOf('(') + 1, typeString.indexOf(')')); + String[] split = precisionAndScale.split(","); + assert split.length == 2; + primitiveType.setDecimalPrecision(Integer.parseInt(split[0].trim())); + primitiveType.setDecimalScale(Integer.parseInt(split[1].trim())); + size = 1; + } else if (typeString.startsWith("DATEV2")) { + primitiveType = DorisDataType.DATE; + size = 2; + } else if (typeString.startsWith("DATE")) { + primitiveType = DorisDataType.DATE; + size = 1; + } else if (typeString.startsWith("DATETIMEV2")) { + primitiveType = DorisDataType.DATETIME; + size = 2; + } else if (typeString.startsWith("DATETIME")) { + primitiveType = DorisDataType.DATETIME; + size = 1; + } else if (typeString.startsWith("CHAR") || typeString.startsWith("VARCHAR")) { + primitiveType = DorisDataType.VARCHAR; + String varcharLength = typeString.substring(typeString.indexOf('(') + 1, typeString.indexOf(')')); + primitiveType.setVarcharLength(Integer.parseInt(varcharLength.trim())); + } else { + switch (typeString) { + case "LARGEINT": + primitiveType = DorisDataType.INT; + size = 16; + break; + case "BIGINT": + primitiveType = DorisDataType.INT; + size = 8; + break; + case "INT": + primitiveType = DorisDataType.INT; + size = 4; + break; + case "SMALLINT": + primitiveType = DorisDataType.INT; + size = 2; + break; + case "TINYINT": + primitiveType = DorisDataType.INT; + size = 1; + break; + case "DOUBLE": + primitiveType = DorisDataType.FLOAT; + size = 12; + break; + case "FLOAT": + primitiveType = DorisDataType.FLOAT; + size = 4; + break; + case "DECIMAL": + case "DECIMAL(*,*)": + primitiveType = DorisDataType.DECIMAL; + size = 1; + break; + case "DECIMALV3": + case "DECIMALV3(*,*)": + primitiveType = DorisDataType.DECIMAL; + size = 3; + break; + case "CHAR": + case "CHAR(*)": + case "VARCHAR": + case "VARCHAR(*)": + primitiveType = DorisDataType.VARCHAR; + break; + case "DATE": + primitiveType = DorisDataType.DATE; + size = 1; + break; + case "DATEV2": + primitiveType = DorisDataType.DATE; + size = 2; + break; + case "DATETIME": + primitiveType = DorisDataType.DATETIME; + size = 1; + break; + case "DATETIMEV2": + primitiveType = DorisDataType.DATETIME; + size = 2; + break; + case "BOOLEAN": + primitiveType = DorisDataType.BOOLEAN; + break; + // case "HLL": + // primitiveType = DorisDataType.HLL; + // break; + // case "BITMAP": + // primitiveType = DorisDataType.BITMAP; + // break; + case "NULL": + primitiveType = DorisDataType.NULL; + break; + default: + throw new AssertionError(typeString); + } + } + return new DorisCompositeDataType(primitiveType, size); + } + + public static class DorisTable extends AbstractRelationalTable { + + public DorisTable(String tableName, List columns, boolean isView) { + super(tableName, columns, Collections.emptyList(), isView); + } + + public List getRandomNonEmptyInsertColumns() { + List columns = getColumns(); + List retColumns = new ArrayList<>(); + List remainColumns = new ArrayList<>(); + for (DorisColumn column : columns) { + if (!column.hasDefaultValue() && !column.isNullable) { + retColumns.add(column); + } else { + remainColumns.add(column); + } + } + if (retColumns.isEmpty()) { + retColumns.addAll(Randomly.nonEmptySubset(remainColumns)); + } else { + retColumns.addAll(Randomly.subset(remainColumns)); + } + return retColumns; + } + + } + + public static DorisSchema fromConnection(SQLConnection con, String databaseName) throws SQLException { + List databaseTables = new ArrayList<>(); + List tableNames = getTableNames(con); + for (String tableName : tableNames) { + if (DBMSCommon.matchesIndexName(tableName)) { + continue; + } + List databaseColumns = getTableColumns(con, tableName); + boolean isView = matchesViewName(tableName); + DorisTable t = new DorisTable(tableName, databaseColumns, isView); + for (DorisColumn c : databaseColumns) { + c.setTable(t); + } + databaseTables.add(t); + + } + return new DorisSchema(databaseTables); + } + + private static List getTableNames(SQLConnection con) throws SQLException { + List tableNames = new ArrayList<>(); + try (Statement s = con.createStatement()) { + try (ResultSet rs = s.executeQuery("SHOW TABLES")) { + while (rs.next()) { + tableNames.add(rs.getString(1)); + } + } + } + return tableNames; + } + + private static List getTableColumns(SQLConnection con, String tableName) throws SQLException { + List columns = new ArrayList<>(); + try (Statement s = con.createStatement()) { + try (ResultSet rs = s.executeQuery("DESCRIBE " + tableName)) { + while (rs.next()) { + String columnName = rs.getString("Field"); + String dataType = rs.getString("Type"); + String isNullString = rs.getString("Null"); + assert isNullString.contentEquals("Yes") || isNullString.contentEquals("No"); + boolean isNullable = isNullString.contentEquals("Yes"); + String isKeyString = rs.getString("Key"); + assert isKeyString.contentEquals("true") || isKeyString.contentEquals("false"); + boolean isKey = isKeyString.contentEquals("true"); + String defaultValue = rs.getString("Default"); + boolean hasDefaultValue = defaultValue != null; + DorisColumn c = new DorisColumn(columnName, getColumnType(dataType), isKey, isNullable, + DorisColumnAggrType.NULL, hasDefaultValue, defaultValue); + columns.add(c); + } + } + } + return columns; + } + +} diff --git a/src/sqlancer/doris/ast/DorisAggregateOperation.java b/src/sqlancer/doris/ast/DorisAggregateOperation.java new file mode 100644 index 000000000..4c60985a8 --- /dev/null +++ b/src/sqlancer/doris/ast/DorisAggregateOperation.java @@ -0,0 +1,37 @@ +package sqlancer.doris.ast; + +import java.util.List; + +import sqlancer.Randomly; + +public class DorisAggregateOperation extends DorisFunction + implements DorisExpression { + + public DorisAggregateOperation(List args, DorisAggregateFunction func) { + super(args, func); + } + + public enum DorisAggregateFunction { + COLLECT_SET(1), MIN(1), STDDEV_SAMP(1), AVG(1), AVG_WEIGHTED(2), PERCENTILE(1), PERCENTILE_ARRAY(2), + HLL_UNION_AGG(1), TOPN(2), TOPN_ARRAY(2), TOPN_WEIGHTED(3), COUNT(1), SUM(1), MAX_BY(2), BITMAP_UNION(1), + GROUP_BITMAP_XOR(1), GROUP_BIT_ADD(1), GROUP_BIT_OR(1), GROUP_BIT_XOR(1), PERCENTILE_APPROX(2), STDDEV(1), + STDDEV_POP(1), GROUP_CONCAT(1), COLLECT_LIST(1), MIN_BY(2), MAX(1), ANY_VALUE(1), VAR_SAMP(1), VARIANCE_SAMP(1), + APPROX_COUNT_DISTINCT(1), VARIANCE(1), VAR_POP(1), VARIANCE_POP(1), GROUPING(1), GROUPING_ID(1); + // RETENTION(1), SEQUENCE_MATCH(1), SEQUENCE_COUNT(1), // TODO,not currently considered + + private int nrArgs; + + DorisAggregateFunction(int nrArgs) { + this.nrArgs = nrArgs; + } + + public static DorisAggregateFunction getRandom() { + return Randomly.fromOptions(values()); + } + + public int getNrArgs() { + return nrArgs; + } + + } +} diff --git a/src/sqlancer/doris/ast/DorisAlias.java b/src/sqlancer/doris/ast/DorisAlias.java new file mode 100644 index 000000000..74acf1cff --- /dev/null +++ b/src/sqlancer/doris/ast/DorisAlias.java @@ -0,0 +1,9 @@ +package sqlancer.doris.ast; + +import sqlancer.common.ast.newast.NewAliasNode; + +public class DorisAlias extends NewAliasNode implements DorisExpression { + public DorisAlias(DorisExpression expr, String text) { + super(expr, text); + } +} diff --git a/src/sqlancer/doris/ast/DorisBetweenOperation.java b/src/sqlancer/doris/ast/DorisBetweenOperation.java new file mode 100644 index 000000000..96c0b428c --- /dev/null +++ b/src/sqlancer/doris/ast/DorisBetweenOperation.java @@ -0,0 +1,38 @@ +package sqlancer.doris.ast; + +import sqlancer.common.ast.newast.NewBetweenOperatorNode; +import sqlancer.doris.DorisSchema; + +public class DorisBetweenOperation extends NewBetweenOperatorNode implements DorisExpression { + public DorisBetweenOperation(DorisExpression left, DorisExpression middle, DorisExpression right, boolean isTrue) { + super(left, middle, right, isTrue); + } + + public DorisExpression getLeftExpr() { + return left; + } + + public DorisExpression getMiddleExpr() { + return middle; + } + + public DorisExpression getRightExpr() { + return right; + } + + @Override + public DorisConstant getExpectedValue() { + DorisBinaryComparisonOperation leftComparison = new DorisBinaryComparisonOperation(getMiddleExpr(), + getLeftExpr(), DorisBinaryComparisonOperation.DorisBinaryComparisonOperator.LESS_EQUALS); + DorisBinaryComparisonOperation rightComparison = new DorisBinaryComparisonOperation(getLeftExpr(), + getRightExpr(), DorisBinaryComparisonOperation.DorisBinaryComparisonOperator.LESS_EQUALS); + return new DorisBinaryLogicalOperation(leftComparison, rightComparison, + DorisBinaryLogicalOperation.DorisBinaryLogicalOperator.AND).getExpectedValue(); + } + + @Override + public DorisSchema.DorisDataType getExpectedType() { + return DorisSchema.DorisDataType.BOOLEAN; + } + +} diff --git a/src/sqlancer/doris/ast/DorisBinaryArithmeticOperation.java b/src/sqlancer/doris/ast/DorisBinaryArithmeticOperation.java new file mode 100644 index 000000000..2e148c488 --- /dev/null +++ b/src/sqlancer/doris/ast/DorisBinaryArithmeticOperation.java @@ -0,0 +1,142 @@ +package sqlancer.doris.ast; + +import java.util.function.BinaryOperator; + +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.newast.NewBinaryOperatorNode; +import sqlancer.doris.DorisSchema.DorisDataType; + +public class DorisBinaryArithmeticOperation extends NewBinaryOperatorNode implements DorisExpression { + + public DorisBinaryArithmeticOperation(DorisExpression left, DorisExpression right, BinaryOperatorNode.Operator op) { + super(left, right, op); + } + + public enum DorisBinaryArithmeticOperator implements BinaryOperatorNode.Operator { + ADDITION("+") { + @Override + public DorisConstant apply(DorisConstant left, DorisConstant right) { + return applyOperation(left, right, (l, r) -> l + r); + } + }, + SUBTRACTION("-") { + @Override + public DorisConstant apply(DorisConstant left, DorisConstant right) { + return applyOperation(left, right, (l, r) -> l - r); + } + }, + MULTIPLICATION("*") { + @Override + public DorisConstant apply(DorisConstant left, DorisConstant right) { + return applyOperation(left, right, (l, r) -> l * r); + } + }, + DIVISION("/") { + @Override + public DorisConstant apply(DorisConstant left, DorisConstant right) { + return applyOperation(left, right, (l, r) -> r == 0 ? -1 : l / r); + } + }, + MODULO("%") { + @Override + public DorisConstant apply(DorisConstant left, DorisConstant right) { + return applyOperation(left, right, (l, r) -> r == 0 ? -1 : l % r); + } + }, + CONCAT("||") { + @Override + public DorisConstant apply(DorisConstant left, DorisConstant right) { + if (!left.isBoolean() || !right.isBoolean()) { + return DorisConstant.createNullConstant(); + } + return applyOperation(left, right, (l, r) -> l == 1 || r == 1 ? 1.0 : 0.0); + } + }, + BIT_AND("&") { + @Override + public DorisConstant apply(DorisConstant left, DorisConstant right) { + if (!left.isInt() || !right.isInt()) { + return DorisConstant.createNullConstant(); + } + return applyOperation(left, right, (l, r) -> (double) ((int) l.doubleValue() & (int) r.doubleValue())); + } + }, + BIT_OR("|") { + @Override + public DorisConstant apply(DorisConstant left, DorisConstant right) { + if (!left.isInt() || !right.isInt()) { + return DorisConstant.createNullConstant(); + } + return applyOperation(left, right, (l, r) -> (double) ((int) l.doubleValue() | (int) r.doubleValue())); + } + }, + LSHIFT("<<") { + @Override + public DorisConstant apply(DorisConstant left, DorisConstant right) { + if (!left.isInt() || !right.isInt()) { + return DorisConstant.createNullConstant(); + } + return applyOperation(left, right, (l, r) -> (double) ((int) l.doubleValue() << (int) r.doubleValue())); + } + }, + RSHIFT(">>") { + @Override + public DorisConstant apply(DorisConstant left, DorisConstant right) { + if (!left.isInt() || !right.isInt()) { + return DorisConstant.createNullConstant(); + } + return applyOperation(left, right, (l, r) -> (double) ((int) l.doubleValue() >> (int) r.doubleValue())); + } + }; + + private final String textRepresentation; + + DorisBinaryArithmeticOperator(String text) { + textRepresentation = text; + } + + public abstract DorisConstant apply(DorisConstant left, DorisConstant right); + + public DorisConstant applyOperation(DorisConstant left, DorisConstant right, BinaryOperator op) { + if (left.isNull() || right.isNull()) { + return DorisConstant.createNullConstant(); + } + double leftVal = left.cast(DorisDataType.FLOAT).asFloat(); + double rightVal = right.cast(DorisDataType.FLOAT).asFloat(); + return DorisConstant.createFloatConstant(op.apply(leftVal, rightVal)); + } + + @Override + public String getTextRepresentation() { + return textRepresentation; + } + } + + public DorisExpression getLeftExpr() { + return super.getLeft(); + } + + public DorisExpression getRightExpr() { + return super.getRight(); + } + + public DorisBinaryArithmeticOperator getOp() { + return (DorisBinaryArithmeticOperator) op; + } + + @Override + public DorisConstant getExpectedValue() { + DorisConstant leftValue = getLeftExpr().getExpectedValue(); + DorisConstant rightValue = getRightExpr().getExpectedValue(); + if (leftValue == null || rightValue == null) { + return null; + } + return getOp().apply(leftValue, rightValue); + } + + @Override + public DorisDataType getExpectedType() { + return DorisDataType.FLOAT; + } + +} diff --git a/src/sqlancer/doris/ast/DorisBinaryComparisonOperation.java b/src/sqlancer/doris/ast/DorisBinaryComparisonOperation.java new file mode 100644 index 000000000..13355c7ff --- /dev/null +++ b/src/sqlancer/doris/ast/DorisBinaryComparisonOperation.java @@ -0,0 +1,119 @@ +package sqlancer.doris.ast; + +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.newast.NewBinaryOperatorNode; +import sqlancer.doris.DorisSchema.DorisDataType; + +public class DorisBinaryComparisonOperation extends NewBinaryOperatorNode implements DorisExpression { + + public DorisBinaryComparisonOperation(DorisExpression left, DorisExpression right, + DorisBinaryComparisonOperator op) { + super(left, right, op); + } + + public DorisExpression getLeftExpression() { + return super.getLeft(); + } + + public DorisExpression getRightExpression() { + return super.getRight(); + } + + public DorisBinaryComparisonOperator getOp() { + return (DorisBinaryComparisonOperator) op; + } + + @Override + public DorisDataType getExpectedType() { + return DorisDataType.BOOLEAN; + } + + @Override + public DorisConstant getExpectedValue() { + DorisConstant leftExpectedValue = getLeftExpression().getExpectedValue(); + DorisConstant rightExpectedValue = getRightExpression().getExpectedValue(); + if (leftExpectedValue == null || rightExpectedValue == null) { + return null; + } + return getOp().apply(leftExpectedValue, rightExpectedValue); + } + + public enum DorisBinaryComparisonOperator implements BinaryOperatorNode.Operator { + EQUALS("=") { + @Override + public DorisConstant apply(DorisConstant left, DorisConstant right) { + return left.valueEquals(right); + } + }, + NOT_EQUALS("!=") { + @Override + public DorisConstant apply(DorisConstant left, DorisConstant right) { + DorisConstant valueEquals = left.valueEquals(right); + if (valueEquals.isBoolean()) { + return DorisConstant.createBooleanConstant(!valueEquals.asBoolean()); + } + // maybe DorisNULLConstant or null object + return valueEquals; + } + }, + LESS("<") { + @Override + public DorisConstant apply(DorisConstant left, DorisConstant right) { + return left.valueLessThan(right); + } + }, + LESS_EQUALS("<=") { + @Override + public DorisConstant apply(DorisConstant left, DorisConstant right) { + DorisConstant valueLessThan = left.valueLessThan(right); + DorisConstant valueEquals = left.valueEquals(right); + if (valueEquals.isBoolean() && valueEquals.asBoolean()) { + return valueEquals; + } + return valueLessThan; + } + }, + GREATER(">") { + @Override + public DorisConstant apply(DorisConstant left, DorisConstant right) { + DorisConstant valueLessThan = left.valueLessThan(right); + DorisConstant valueEquals = left.valueEquals(right); + if (valueEquals.isBoolean() && valueEquals.asBoolean()) { + return DorisConstant.createBooleanConstant(false); + } + if (valueLessThan.isNull()) { + return valueLessThan; + } + return DorisConstant.createBooleanConstant(!valueLessThan.asBoolean()); + } + }, + GREATER_EQUALS(">=") { + @Override + public DorisConstant apply(DorisConstant left, DorisConstant right) { + DorisConstant valueLessThan = left.valueLessThan(right); + DorisConstant valueEquals = left.valueEquals(right); + if (valueEquals.isBoolean() && valueEquals.asBoolean()) { + return DorisConstant.createBooleanConstant(true); + } + if (valueLessThan.isNull()) { + return valueLessThan; + } + return DorisConstant.createBooleanConstant(!valueLessThan.asBoolean()); + } + }; + + private final String textRepresentation; + + DorisBinaryComparisonOperator(String text) { + textRepresentation = text; + } + + public abstract DorisConstant apply(DorisConstant left, DorisConstant right); + + @Override + public String getTextRepresentation() { + return textRepresentation; + } + } + +} diff --git a/src/sqlancer/doris/ast/DorisBinaryLogicalOperation.java b/src/sqlancer/doris/ast/DorisBinaryLogicalOperation.java new file mode 100644 index 000000000..858c3bf2f --- /dev/null +++ b/src/sqlancer/doris/ast/DorisBinaryLogicalOperation.java @@ -0,0 +1,129 @@ +package sqlancer.doris.ast; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.newast.NewBinaryOperatorNode; +import sqlancer.doris.DorisSchema.DorisDataType; + +public class DorisBinaryLogicalOperation extends NewBinaryOperatorNode implements DorisExpression { + + public DorisBinaryLogicalOperation(DorisExpression left, DorisExpression right, DorisBinaryLogicalOperator op) { + super(left, right, op); + } + + public DorisExpression getLeftExpr() { + return super.getLeft(); + } + + public DorisExpression getRightExpr() { + return super.getRight(); + } + + public DorisBinaryLogicalOperator getOp() { + return (DorisBinaryLogicalOperator) op; + } + + @Override + public DorisConstant getExpectedValue() { + DorisConstant leftValue = getLeftExpr().getExpectedValue(); + DorisConstant rightValue = getRightExpr().getExpectedValue(); + if (leftValue == null || rightValue == null) { + return null; + } + return getOp().apply(leftValue, rightValue); + } + + @Override + public DorisDataType getExpectedType() { + return DorisDataType.BOOLEAN; + } + + public enum DorisBinaryLogicalOperator implements BinaryOperatorNode.Operator { + /* + * null and false -> false null and true -> null null or false -> null null or true -> true + */ + AND("AND", "and") { + @Override + public DorisConstant apply(DorisConstant left, DorisConstant right) { + DorisConstant leftVal = left.cast(DorisDataType.BOOLEAN); + DorisConstant rightVal = right.cast(DorisDataType.BOOLEAN); + assert leftVal.isNull() || leftVal.isBoolean() : leftVal + "is not null or boolean"; + assert rightVal.isNull() || rightVal.isBoolean() : rightVal + "is not null or boolean"; + if (leftVal.isNull() && rightVal.isNull()) { + return DorisConstant.createNullConstant(); + } + if (leftVal.isNull()) { + if (!rightVal.asBoolean()) { + return DorisConstant.createBooleanConstant(false); + } else { + return DorisConstant.createNullConstant(); + } + } + if (rightVal.isNull()) { + if (!leftVal.asBoolean()) { + return DorisConstant.createBooleanConstant(false); + } else { + return DorisConstant.createNullConstant(); + } + } + if (leftVal.asBoolean() && right.asBoolean()) { + return DorisConstant.createBooleanConstant(true); + } + return DorisConstant.createBooleanConstant(false); + } + }, + OR("OR", "or") { + @Override + public DorisConstant apply(DorisConstant left, DorisConstant right) { + DorisConstant leftVal = left.cast(DorisDataType.BOOLEAN); + DorisConstant rightVal = right.cast(DorisDataType.BOOLEAN); + assert leftVal.isNull() || leftVal.isBoolean() : leftVal + "is not null or boolean"; + assert rightVal.isNull() || rightVal.isBoolean() : rightVal + "is not null or boolean"; + if (leftVal.isNull() && rightVal.isNull()) { + return DorisConstant.createNullConstant(); + } + if (leftVal.isNull()) { + if (rightVal.asBoolean()) { + return DorisConstant.createBooleanConstant(true); + } else { + return DorisConstant.createNullConstant(); + } + } + if (rightVal.isNull()) { + if (leftVal.asBoolean()) { + return DorisConstant.createBooleanConstant(true); + } else { + return DorisConstant.createNullConstant(); + } + } + if (leftVal.asBoolean() || right.asBoolean()) { + return DorisConstant.createBooleanConstant(true); + } + return DorisConstant.createBooleanConstant(false); + } + }; + + private final String[] textRepresentations; + + DorisBinaryLogicalOperator(String... textRepresentations) { + this.textRepresentations = textRepresentations.clone(); + } + + @Override + public String getTextRepresentation() { + return Randomly.fromOptions(textRepresentations); + } + + public DorisBinaryLogicalOperator getRandomOp() { + return Randomly.fromOptions(values()); + } + + public static DorisBinaryLogicalOperator getRandom() { + return Randomly.fromOptions(values()); + } + + public abstract DorisConstant apply(DorisConstant left, DorisConstant right); + + } + +} diff --git a/src/sqlancer/doris/ast/DorisBinaryOperation.java b/src/sqlancer/doris/ast/DorisBinaryOperation.java new file mode 100644 index 000000000..ffe80b86d --- /dev/null +++ b/src/sqlancer/doris/ast/DorisBinaryOperation.java @@ -0,0 +1,10 @@ +package sqlancer.doris.ast; + +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.newast.NewBinaryOperatorNode; + +public class DorisBinaryOperation extends NewBinaryOperatorNode implements DorisExpression { + public DorisBinaryOperation(DorisExpression left, DorisExpression right, BinaryOperatorNode.Operator op) { + super(left, right, op); + } +} diff --git a/src/sqlancer/doris/ast/DorisCaseOperation.java b/src/sqlancer/doris/ast/DorisCaseOperation.java new file mode 100644 index 000000000..02641408a --- /dev/null +++ b/src/sqlancer/doris/ast/DorisCaseOperation.java @@ -0,0 +1,36 @@ +package sqlancer.doris.ast; + +import java.util.List; + +public class DorisCaseOperation implements DorisExpression { + + private final DorisExpression expr; + private final List conditions; + private final List thenClauses; + private final DorisExpression elseClause; + + public DorisCaseOperation(DorisExpression expr, List conditions, List thenClauses, + DorisExpression elseClause) { + this.expr = expr; + this.conditions = conditions; + this.thenClauses = thenClauses; + this.elseClause = elseClause; + } + + public DorisExpression getExpr() { + return expr; + } + + public List getConditions() { + return conditions; + } + + public List getThenClauses() { + return thenClauses; + } + + public DorisExpression getElseClause() { + return elseClause; + } + +} diff --git a/src/sqlancer/doris/ast/DorisCastOperation.java b/src/sqlancer/doris/ast/DorisCastOperation.java new file mode 100644 index 000000000..e93b8b226 --- /dev/null +++ b/src/sqlancer/doris/ast/DorisCastOperation.java @@ -0,0 +1,45 @@ +package sqlancer.doris.ast; + +import sqlancer.doris.DorisSchema.DorisCompositeDataType; +import sqlancer.doris.DorisSchema.DorisDataType; + +public class DorisCastOperation implements DorisExpression { + DorisExpression expr; + DorisDataType type; + + public DorisCastOperation(DorisExpression expr, DorisCompositeDataType type) { + this.expr = expr; + this.type = type.getPrimitiveDataType(); + } + + public DorisCastOperation(DorisExpression expr, DorisDataType type) { + this.expr = expr; + this.type = type; + } + + public DorisExpression getExpr() { + return expr; + } + + public DorisExpression getExpression() { + return expr; + } + + public DorisDataType getType() { + return type; + } + + @Override + public DorisConstant getExpectedValue() { + DorisConstant expectedValue = getExpression().getExpectedValue(); + if (expectedValue == null) { + return null; + } + return expectedValue.cast(type); + } + + @Override + public DorisDataType getExpectedType() { + return type; + } +} diff --git a/src/sqlancer/doris/ast/DorisColumnReference.java b/src/sqlancer/doris/ast/DorisColumnReference.java new file mode 100644 index 000000000..ce6aee058 --- /dev/null +++ b/src/sqlancer/doris/ast/DorisColumnReference.java @@ -0,0 +1,11 @@ +package sqlancer.doris.ast; + +import sqlancer.common.ast.newast.ColumnReferenceNode; +import sqlancer.doris.DorisSchema; + +public class DorisColumnReference extends ColumnReferenceNode + implements DorisExpression { + public DorisColumnReference(DorisSchema.DorisColumn column) { + super(column); + } +} diff --git a/src/sqlancer/doris/ast/DorisColumnValue.java b/src/sqlancer/doris/ast/DorisColumnValue.java new file mode 100644 index 000000000..fe3ec1448 --- /dev/null +++ b/src/sqlancer/doris/ast/DorisColumnValue.java @@ -0,0 +1,53 @@ +package sqlancer.doris.ast; + +import java.util.Objects; + +import sqlancer.common.ast.newast.ColumnReferenceNode; +import sqlancer.doris.DorisSchema.DorisColumn; +import sqlancer.doris.DorisSchema.DorisDataType; + +public class DorisColumnValue extends ColumnReferenceNode implements DorisExpression { + + private final DorisConstant expectedValue; + + public DorisColumnValue(DorisColumn column, DorisConstant value) { + super(column); + this.expectedValue = value; + } + + @Override + public DorisConstant getExpectedValue() { + return expectedValue; + } + + @Override + public DorisDataType getExpectedType() { + return getColumn().getType().getPrimitiveDataType(); + } + + public static DorisColumnValue create(DorisColumn column, DorisConstant value) { + return new DorisColumnValue(column, value); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + DorisColumnValue that = (DorisColumnValue) o; + if (!this.getColumn().getName().equals(that.getColumn().getName())) { + return false; + } + return Objects.equals(expectedValue, that.expectedValue); + } + + @Override + public int hashCode() { + String nameAndValue = this.getColumn().getName(); + nameAndValue += expectedValue == null ? "NULL" : expectedValue.toString(); + return Objects.hash(nameAndValue); + } +} diff --git a/src/sqlancer/doris/ast/DorisConstant.java b/src/sqlancer/doris/ast/DorisConstant.java new file mode 100644 index 000000000..2ad2500ed --- /dev/null +++ b/src/sqlancer/doris/ast/DorisConstant.java @@ -0,0 +1,680 @@ +package sqlancer.doris.ast; + +import sqlancer.doris.DorisSchema.DorisDataType; +import sqlancer.doris.utils.DorisNumberUtils; + +public abstract class DorisConstant implements DorisExpression { + + private DorisConstant() { + } + + public boolean isNull() { + return false; + } + + public boolean isInt() { + return false; + } + + public boolean isBoolean() { + return false; + } + + public boolean isString() { + return false; + } + + public boolean isFloat() { + return false; + } + + public boolean isNum() { + // for INT, FLOAT, BOOLEAN + return false; + } + + public boolean isDate() { + return false; + } + + public boolean isDatetime() { + return false; + } + + public boolean asBoolean() { + throw new UnsupportedOperationException(this.toString()); + } + + public long asInt() { + throw new UnsupportedOperationException(this.toString()); + } + + public String asString() { + throw new UnsupportedOperationException(this.toString()); + } + + public double asFloat() { + throw new UnsupportedOperationException(this.toString()); + } + + public abstract DorisConstant cast(DorisDataType dataType); + + public abstract DorisConstant valueEquals(DorisConstant rightVal); + + public abstract DorisConstant valueLessThan(DorisConstant rightVal); + + public static class DorisNullConstant extends DorisConstant { + + @Override + public String toString() { + return "NULL"; + } + + @Override + public boolean isNull() { + return true; + } + + @Override + public DorisConstant cast(DorisDataType dataType) { + return DorisConstant.createNullConstant(); + } + + @Override + public DorisConstant valueEquals(DorisConstant rightVal) { + return DorisConstant.createNullConstant(); + } + + @Override + public DorisConstant valueLessThan(DorisConstant rightVal) { + return DorisConstant.createNullConstant(); + } + + @Override + public DorisDataType getExpectedType() { + return DorisDataType.NULL; + } + } + + public static class DorisIntConstant extends DorisConstant { + + private final long value; + + public DorisIntConstant(long value) { + this.value = value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + + public long getValue() { + return value; + } + + @Override + public boolean isInt() { + return true; + } + + @Override + public boolean isNum() { + return true; + } + + @Override + public DorisConstant cast(DorisDataType dataType) { + switch (dataType) { + case INT: + return this; + case FLOAT: + case DECIMAL: + return new DorisFloatConstant(value); + case VARCHAR: + return new DorisTextConstant(String.valueOf(value)); + case BOOLEAN: + return new DorisBooleanConstant(value != 0); + default: + return DorisConstant.createNullConstant(); + } + } + + @Override + public long asInt() { + return value; + } + + @Override + public boolean asBoolean() { + return value != 0; + } + + @Override + public double asFloat() { + return value; + } + + @Override + public String asString() { + return String.valueOf(value); + } + + @Override + public DorisConstant valueEquals(DorisConstant rightVal) { + if (rightVal.isNull()) { + return DorisConstant.createNullConstant(); + } + if (rightVal.isNum()) { + return DorisConstant.createBooleanConstant(value == rightVal.asFloat()); + } + + throw new AssertionError(rightVal); + } + + @Override + public DorisConstant valueLessThan(DorisConstant rightVal) { + if (rightVal.isNull()) { + return DorisConstant.createNullConstant(); + } + if (rightVal.isNum()) { + return DorisConstant.createBooleanConstant(value < rightVal.asFloat()); + } + + throw new AssertionError(rightVal); + } + + @Override + public DorisDataType getExpectedType() { + return DorisDataType.INT; + } + + } + + public static class DorisFloatConstant extends DorisConstant { + + private final double value; + + public DorisFloatConstant(double value) { + this.value = value; + } + + public double getValue() { + return value; + } + + @Override + public boolean isFloat() { + return true; + } + + @Override + public boolean isNum() { + return true; + } + + @Override + public String toString() { + if (value == Double.POSITIVE_INFINITY) { + return "3.40282347e+38"; + } else if (value == Double.NEGATIVE_INFINITY) { + return "-3.40282347e+38"; + } + + return String.valueOf(value); + } + + @Override + public DorisConstant cast(DorisDataType dataType) { + switch (dataType) { + case INT: + return new DorisIntConstant((long) value); + case FLOAT: + case DECIMAL: + return this; + case VARCHAR: + return new DorisTextConstant(String.valueOf(value)); + case BOOLEAN: + return new DorisBooleanConstant(value >= 1); + default: + return null; + } + } + + @Override + public double asFloat() { + return value; + } + + @Override + public String asString() { + return toString(); + } + + @Override + public DorisConstant valueEquals(DorisConstant rightVal) { + if (rightVal.isNull()) { + return DorisConstant.createNullConstant(); + } else if (rightVal.isInt()) { + return DorisConstant.createBooleanConstant(value == rightVal.asInt()); + } else if (rightVal.isFloat()) { + return DorisConstant.createBooleanConstant(value < rightVal.asFloat()); + } else { + throw new AssertionError(rightVal); + } + } + + @Override + public DorisConstant valueLessThan(DorisConstant rightVal) { + if (rightVal.isNull()) { + return DorisConstant.createNullConstant(); + } else if (rightVal.isInt()) { + return DorisConstant.createBooleanConstant(value < rightVal.asInt()); + } else if (rightVal.isFloat()) { + return DorisConstant.createBooleanConstant(value < rightVal.asFloat()); + } else { + throw new AssertionError(rightVal); + } + } + + } + + public static class DorisTextConstant extends DorisConstant { + + private final String value; + + public DorisTextConstant(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + + @Override + public String toString() { + return "'" + value.replace("\\", "\\\\").replace("'", "\\'") + "'"; + } + + @Override + public String asString() { + return value; + } + + @Override + public boolean isString() { + return true; + } + + @Override + public DorisConstant cast(DorisDataType dataType) { + switch (dataType) { + case INT: + // Currently only supports conversion of int text to int, not float text, see + // https://github.com/apache/doris/issues/18227 + if (DorisNumberUtils.isNumber(value)) { + long val = (long) Double.parseDouble(value); + return new DorisIntConstant(val); + } + return new DorisNullConstant(); + case FLOAT: + case DECIMAL: + if (DorisNumberUtils.isNumber(value)) { + return new DorisFloatConstant(Double.parseDouble(value)); + } + return new DorisNullConstant(); + case DATE: + if (DorisNumberUtils.isDate(value)) { + return new DorisDateConstant(value); + } + return new DorisNullConstant(); + case DATETIME: + if (DorisNumberUtils.isDatetime(value)) { + return new DorisDatetimeConstant(value); + } + return new DorisNullConstant(); + case VARCHAR: + return this; + case BOOLEAN: + if ("false".contentEquals(value.toLowerCase())) { + return new DorisBooleanConstant(false); + } + if ("true".contentEquals(value.toLowerCase())) { + return new DorisBooleanConstant(true); + } + return new DorisNullConstant(); + default: + return new DorisNullConstant(); + } + } + + @Override + public DorisConstant valueEquals(DorisConstant rightVal) { + if (rightVal.isNull()) { + return DorisConstant.createNullConstant(); + } + if (rightVal.isString()) { + return DorisConstant.createBooleanConstant(value.contentEquals(rightVal.asString())); + } + if (DorisNumberUtils.isNumber(value) && rightVal.isNum()) { + return DorisConstant.createBooleanConstant(Double.parseDouble(value) == rightVal.asFloat()); + } + // Doris currently does not support judgment between string types and other types, such date, datetime + return DorisConstant.createBooleanConstant(false); + } + + @Override + public DorisConstant valueLessThan(DorisConstant rightVal) { + if (rightVal.isNull()) { + return DorisConstant.createNullConstant(); + } + if (rightVal.isString()) { + return DorisConstant.createBooleanConstant(value.compareTo(rightVal.asString()) < 0); + } + if (DorisNumberUtils.isNumber(value) && rightVal.isNum()) { + return DorisConstant.createBooleanConstant(Double.parseDouble(value) < rightVal.asFloat()); + } + // Doris currently does not support judgment between string types and other types, such date, datetime + return DorisConstant.createBooleanConstant(false); + } + + } + + public static class DorisDateConstant extends DorisConstant { + + public String textRepr; + + public DorisDateConstant(long val) { + textRepr = DorisNumberUtils.timestampToDateText(val); + } + + public DorisDateConstant(String textRepr) { + this.textRepr = textRepr; + } + + public String getValue() { + return textRepr; + } + + @Override + public String toString() { + return String.format("DATE '%s'", textRepr); + } + + @Override + public String asString() { + return textRepr; + } + + @Override + public DorisConstant cast(DorisDataType dataType) { + switch (dataType) { + case VARCHAR: + return new DorisTextConstant(textRepr); + case DATE: + return this; + case DATETIME: + return new DorisDatetimeConstant(DorisNumberUtils.dateTextToDatetimeText(textRepr)); + default: + return new DorisNullConstant(); + } + } + + @Override + public DorisConstant valueEquals(DorisConstant rightVal) { + if (rightVal.isNull()) { + return DorisConstant.createNullConstant(); + } + if (rightVal.isDatetime() && rightVal.asString().contentEquals("CURRENT_TIMESTAMP")) { + return DorisConstant.createBooleanConstant(false); + } + if (rightVal.isString() || rightVal.isDate() || rightVal.isDatetime()) { + return DorisConstant.createBooleanConstant(DorisNumberUtils.dateEqual(textRepr, rightVal.asString())); + } + return DorisConstant.createBooleanConstant(false); + } + + @Override + public DorisConstant valueLessThan(DorisConstant rightVal) { + if (rightVal.isNull()) { + return DorisConstant.createNullConstant(); + } + if (rightVal.isDatetime() && rightVal.asString().contentEquals("CURRENT_TIMESTAMP")) { + return DorisConstant.createBooleanConstant( + DorisNumberUtils.dateLessThan(textRepr, DorisNumberUtils.getCurrentTimeText())); + } + if (rightVal.isString() || rightVal.isDate() || rightVal.isDatetime()) { + return DorisConstant + .createBooleanConstant(DorisNumberUtils.dateLessThan(textRepr, rightVal.asString())); + } + return DorisConstant.createBooleanConstant(false); + } + + @Override + public boolean isDate() { + return true; + } + } + + public static class DorisDatetimeConstant extends DorisConstant { + + public String textRepr; + + public DorisDatetimeConstant(long val) { + textRepr = DorisNumberUtils.timestampToDatetimeText(val); + } + + public DorisDatetimeConstant(String textRepr) { + this.textRepr = textRepr; + } + + public DorisDatetimeConstant() { + textRepr = "CURRENT_TIMESTAMP"; + } + + public String getValue() { + return textRepr; + } + + @Override + public String toString() { + return String.format("TIMESTAMP '%s'", textRepr); + } + + @Override + public String asString() { + return textRepr; + } + + @Override + public DorisConstant cast(DorisDataType dataType) { + switch (dataType) { + case VARCHAR: + return new DorisTextConstant(textRepr); + case DATE: + return new DorisDatetimeConstant(DorisNumberUtils.datetimeTextToDateText(textRepr)); + case DATETIME: + return this; + default: + return new DorisNullConstant(); + } + } + + @Override + public DorisConstant valueEquals(DorisConstant rightVal) { + if (rightVal.isNull()) { + return DorisConstant.createNullConstant(); + } + if (rightVal.isDatetime() && (rightVal.asString().contentEquals("CURRENT_TIMESTAMP") + || textRepr.contentEquals("CURRENT_TIMESTAMP"))) { + boolean isEq = rightVal.asString().contentEquals("CURRENT_TIMESTAMP") + && textRepr.contentEquals("CURRENT_TIMESTAMP"); + return DorisConstant.createBooleanConstant(isEq); + } + if (rightVal.isString() || rightVal.isDate() || rightVal.isDatetime()) { + return DorisConstant + .createBooleanConstant(DorisNumberUtils.datetimeEqual(textRepr, rightVal.asString())); + } + return DorisConstant.createBooleanConstant(false); + } + + @Override + public DorisConstant valueLessThan(DorisConstant rightVal) { + if (rightVal.isNull()) { + return DorisConstant.createNullConstant(); + } + if (rightVal.isDatetime() && (rightVal.asString().contentEquals("CURRENT_TIMESTAMP") + || textRepr.contentEquals("CURRENT_TIMESTAMP"))) { + String leftText = textRepr; + String rightText = rightVal.asString(); + if (leftText.contentEquals(rightText)) { + return DorisConstant.createBooleanConstant(false); + } + if (leftText.contentEquals("CURRENT_TIMESTAMP")) { + leftText = DorisNumberUtils.getCurrentTimeText(); + } + if (rightText.contentEquals("CURRENT_TIMESTAMP")) { + rightText = DorisNumberUtils.getCurrentTimeText(); + } + boolean lessThan = DorisNumberUtils.dateLessThan(leftText, rightText); + return DorisConstant.createBooleanConstant(lessThan); + } + if (rightVal.isString() || rightVal.isDate() || rightVal.isDatetime()) { + return DorisConstant + .createBooleanConstant(DorisNumberUtils.datetimeLessThan(textRepr, rightVal.asString())); + } + return DorisConstant.createBooleanConstant(false); + } + + @Override + public boolean isDatetime() { + return true; + } + + } + + public static class DorisBooleanConstant extends DorisConstant { + + private final boolean value; + + public DorisBooleanConstant(boolean value) { + this.value = value; + } + + public boolean getValue() { + return value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + + @Override + public String asString() { + return toString(); + } + + @Override + public boolean asBoolean() { + return value; + } + + @Override + public boolean isBoolean() { + return true; + } + + @Override + public boolean isNum() { + return true; + } + + @Override + public DorisConstant cast(DorisDataType dataType) { + switch (dataType) { + case INT: + return new DorisIntConstant(value ? 1 : 0); + case FLOAT: + case DECIMAL: + return new DorisFloatConstant(value ? 1 : 0); + case BOOLEAN: + return this; + case VARCHAR: + return new DorisTextConstant(value ? "1" : "0"); + default: + return null; + } + } + + @Override + public DorisConstant valueEquals(DorisConstant rightVal) { + if (rightVal.isNull()) { + return DorisConstant.createNullConstant(); + } + if (rightVal.isBoolean()) { + return DorisConstant.createBooleanConstant(value == rightVal.asBoolean()); + } + if (rightVal.isNum() || rightVal.isString() && DorisNumberUtils.isNumber(rightVal.asString())) { + return DorisConstant.createBooleanConstant((value ? 1 : 0) == rightVal.asFloat()); + } + throw new AssertionError(rightVal); + } + + @Override + public DorisConstant valueLessThan(DorisConstant rightVal) { + if (rightVal.isNull()) { + return DorisConstant.createNullConstant(); + } + if (rightVal.isBoolean()) { + return DorisConstant.createBooleanConstant(value == rightVal.asBoolean()); + } + if (rightVal.isNum() || rightVal.isString() && DorisNumberUtils.isNumber(rightVal.asString())) { + return DorisConstant.createBooleanConstant((value ? 1 : 0) == rightVal.asFloat()); + } + throw new AssertionError(rightVal); + } + } + + public static DorisConstant createStringConstant(String text) { + return new DorisTextConstant(text); + } + + public static DorisConstant createFloatConstant(double val) { + return new DorisFloatConstant(val); + } + + public static DorisConstant createIntConstant(long val) { + return new DorisIntConstant(val); + } + + public static DorisConstant createNullConstant() { + return new DorisNullConstant(); + } + + public static DorisConstant createBooleanConstant(boolean val) { + return new DorisBooleanConstant(val); + } + + public static DorisConstant createDateConstant(long integer) { + return new DorisDateConstant(integer); + } + + public static DorisConstant createDateConstant(String date) { + return new DorisDateConstant(date); + } + + public static DorisConstant createDatetimeConstant(long integer) { + return new DorisDatetimeConstant(integer); + } + + public static DorisConstant createDatetimeConstant(String datetime) { + return new DorisDatetimeConstant(datetime); + } + + public static DorisConstant createDatetimeConstant() { + // use CURRENT_TIMESTAMP + return new DorisDatetimeConstant(); + } + +} diff --git a/src/sqlancer/doris/ast/DorisExpression.java b/src/sqlancer/doris/ast/DorisExpression.java new file mode 100644 index 000000000..ee5cc7a26 --- /dev/null +++ b/src/sqlancer/doris/ast/DorisExpression.java @@ -0,0 +1,15 @@ +package sqlancer.doris.ast; + +import sqlancer.common.ast.newast.Expression; +import sqlancer.doris.DorisSchema; +import sqlancer.doris.DorisSchema.DorisColumn; + +public interface DorisExpression extends Expression { + default DorisSchema.DorisDataType getExpectedType() { + return null; + } + + default DorisConstant getExpectedValue() { + return null; + } +} diff --git a/src/sqlancer/doris/ast/DorisFunction.java b/src/sqlancer/doris/ast/DorisFunction.java new file mode 100644 index 000000000..1f816a7e3 --- /dev/null +++ b/src/sqlancer/doris/ast/DorisFunction.java @@ -0,0 +1,11 @@ +package sqlancer.doris.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewFunctionNode; + +public class DorisFunction extends NewFunctionNode implements DorisExpression { + public DorisFunction(List args, F func) { + super(args, func); + } +} diff --git a/src/sqlancer/doris/ast/DorisFunctionOperation.java b/src/sqlancer/doris/ast/DorisFunctionOperation.java new file mode 100644 index 000000000..05f5370b1 --- /dev/null +++ b/src/sqlancer/doris/ast/DorisFunctionOperation.java @@ -0,0 +1,281 @@ +package sqlancer.doris.ast; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import sqlancer.Randomly; +import sqlancer.doris.DorisSchema.DorisDataType; +import sqlancer.doris.gen.DorisNewExpressionGenerator; + +public class DorisFunctionOperation implements DorisExpression { + + private DorisFunction function; + private List args; + + // https://doris.apache.org/zh-CN/docs/dev/summary/basic-summary + public enum DorisFunction { + + // Array functions, https://doris.apache.org/docs/dev/sql-manual/sql-functions/array-functions/array + // Skip now + + // Date functions, https://doris.apache.org/docs/dev/sql-manual/sql-functions/date-time-functions/convert_tz/ + CONVERT_TZ(false, DorisDataType.DATETIME, DorisDataType.DATETIME, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + CURDATE(false, DorisDataType.DATE), CURRENT_DATE(false, DorisDataType.DATE), + CURTIME(false, DorisDataType.VARCHAR), CURRENT_TIME(false, DorisDataType.VARCHAR), + CURRENT_TIMESTAMP(false, DorisDataType.DATETIME), LOCALTIME(false, DorisDataType.DATETIME), + LOCALTIMESTAMP(false, DorisDataType.DATETIME), NOW(false, DorisDataType.DATETIME), + YEAR(false, DorisDataType.INT, DorisDataType.DATETIME), + QUARTER(false, DorisDataType.INT, DorisDataType.DATETIME), + MONTH(false, DorisDataType.INT, DorisDataType.DATETIME), DAY(false, DorisDataType.INT, DorisDataType.DATETIME), + DAYOFYEAR(false, DorisDataType.INT, DorisDataType.DATETIME), + DAYOFMONTH(false, DorisDataType.INT, DorisDataType.DATETIME), + DAYOFWEEK(false, DorisDataType.INT, DorisDataType.DATETIME), WEEK(false, DorisDataType.INT, DorisDataType.DATE), + WEEKDAY(false, DorisDataType.INT, DorisDataType.DATE), + WEEKOFYEAR(false, DorisDataType.INT, DorisDataType.DATETIME), + YEARWEEK(false, DorisDataType.INT, DorisDataType.DATE), + DAYNAME(false, DorisDataType.VARCHAR, DorisDataType.DATETIME), + MONTHNAME(false, DorisDataType.VARCHAR, DorisDataType.DATETIME), + HOUR(false, DorisDataType.INT, DorisDataType.DATETIME), + MINUTE(false, DorisDataType.INT, DorisDataType.DATETIME), + SECOND(false, DorisDataType.INT, DorisDataType.DATETIME), + FROM_DAYS(false, DorisDataType.DATE, DorisDataType.INT), + LAST_DAYS(false, DorisDataType.DATE, DorisDataType.DATETIME), + TO_MONDAY(false, DorisDataType.DATE, DorisDataType.DATETIME), + FROM_UNIXTIME(false, DorisDataType.DATETIME, DorisDataType.INT), + UNIX_TIMESTAMP(false, DorisDataType.INT, DorisDataType.DATETIME), UTC_TIMESTAMP(false, DorisDataType.DATETIME), + TO_DATE(false, DorisDataType.DATE, DorisDataType.DATETIME), + TO_DAYS(false, DorisDataType.INT, DorisDataType.DATETIME), + TIME_TO_SEC(false, DorisDataType.INT, DorisDataType.DATETIME), + // EXTRACT(1), // select extract(year from '2022-09-22 17:01:30') as year, currently not considered + MAKEDATE(false, DorisDataType.DATE, DorisDataType.INT, DorisDataType.INT), + STR_TO_DATE(false, DorisDataType.DATETIME, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + TIME_ROUND(false, DorisDataType.DATETIME, DorisDataType.DATETIME), + TIME_DIFF(false, DorisDataType.VARCHAR, DorisDataType.DATETIME, DorisDataType.DATETIME), + TIMESTAMPADD(false, DorisDataType.DATETIME, DorisDataType.VARCHAR, DorisDataType.INT, DorisDataType.DATETIME), + TIMESTAMPDIFF(false, DorisDataType.VARCHAR, DorisDataType.DATETIME, DorisDataType.DATETIME), + DATE_ADD(false, DorisDataType.INT, DorisDataType.DATETIME, DorisDataType.VARCHAR), + DATE_SUB(false, DorisDataType.DATETIME, DorisDataType.DATETIME, DorisDataType.VARCHAR), + DATE_TRUNC(false, DorisDataType.DATETIME, DorisDataType.DATETIME, DorisDataType.VARCHAR), + DATE_FORMAT(false, DorisDataType.VARCHAR, DorisDataType.DATETIME, DorisDataType.VARCHAR), + DATEDIFF(false, DorisDataType.DATETIME, DorisDataType.DATETIME, DorisDataType.DATETIME), + // MICROSECONDS_ADD(false), + MINUTES_ADD(false, DorisDataType.DATETIME, DorisDataType.DATETIME, DorisDataType.INT), + MINUTES_DIFF(false, DorisDataType.INT, DorisDataType.DATETIME, DorisDataType.DATETIME), + MINUTES_SUB(false, DorisDataType.DATETIME, DorisDataType.DATETIME, DorisDataType.INT), + SECONDS_ADD(false, DorisDataType.DATETIME, DorisDataType.DATETIME, DorisDataType.INT), + SECONDS_DIFF(false, DorisDataType.INT, DorisDataType.DATETIME, DorisDataType.DATETIME), + SECONDS_SUB(false, DorisDataType.DATETIME, DorisDataType.DATETIME, DorisDataType.INT), + HOURS_ADD(false, DorisDataType.DATETIME, DorisDataType.DATETIME, DorisDataType.INT), + HOURS_DIFF(false, DorisDataType.INT, DorisDataType.DATETIME, DorisDataType.DATETIME), + HOURS_SUB(false, DorisDataType.DATETIME, DorisDataType.DATETIME, DorisDataType.INT), + DAYS_ADD(false, DorisDataType.DATETIME, DorisDataType.DATETIME, DorisDataType.INT), + DAYS_DIFF(false, DorisDataType.INT, DorisDataType.DATETIME, DorisDataType.DATETIME), + DAYS_SUB(false, DorisDataType.DATETIME, DorisDataType.DATETIME, DorisDataType.INT), + WEEKS_ADD(false, DorisDataType.DATETIME, DorisDataType.DATETIME, DorisDataType.INT), + WEEKS_DIFF(false, DorisDataType.INT, DorisDataType.DATETIME, DorisDataType.DATETIME), + WEEKS_SUB(false, DorisDataType.DATETIME, DorisDataType.DATETIME, DorisDataType.INT), + MONTHS_ADD(false, DorisDataType.DATETIME, DorisDataType.DATETIME, DorisDataType.INT), + MONTHS_DIFF(false, DorisDataType.INT, DorisDataType.DATETIME, DorisDataType.DATETIME), + MONTHS_SUB(false, DorisDataType.DATETIME, DorisDataType.DATETIME, DorisDataType.INT), + YEARS_ADD(false, DorisDataType.DATETIME, DorisDataType.DATETIME, DorisDataType.INT), + YEARS_DIFF(false, DorisDataType.INT, DorisDataType.DATETIME, DorisDataType.DATETIME), + YEARS_SUB(false, DorisDataType.DATETIME, DorisDataType.DATETIME, DorisDataType.INT), + + // GIS functions, https://doris.apache.org/docs/dev/sql-manual/sql-functions/spatial-functions/st_x + // Skip now + + // String functions, https://doris.apache.org/docs/dev/sql-manual/sql-functions/string-functions/to_base64 + TO_BASE64(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + FROM_BASE64(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + ASCII(false, DorisDataType.INT, DorisDataType.VARCHAR), LENGTH(false, DorisDataType.INT, DorisDataType.VARCHAR), + BIT_LENGTH(false, DorisDataType.INT, DorisDataType.VARCHAR), + CHAR_LENGTH(false, DorisDataType.INT, DorisDataType.VARCHAR), + LPAD(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR, DorisDataType.INT, DorisDataType.VARCHAR), + RPAD(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR, DorisDataType.INT, DorisDataType.VARCHAR), + LOWER(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + LCASE(false, DorisDataType.INT, DorisDataType.VARCHAR), + UPPER(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + UCASE(false, DorisDataType.INT, DorisDataType.VARCHAR), + INITCAP(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + REPEAT(false, DorisDataType.VARCHAR, DorisDataType.INT), + REVERSE(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + CHAR(true, DorisDataType.VARCHAR, DorisDataType.INT), + CONCAT(true, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + CONCAT_WS(true, DorisDataType.VARCHAR, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + SUBSTR(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR, DorisDataType.INT, DorisDataType.INT), + SUBSTRING(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR, DorisDataType.INT), + SUB_REPLACE(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR, DorisDataType.VARCHAR, DorisDataType.INT), + APPEND_TRAILING_CHAR_IF_ABSENT(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + ENDS_WITH(false, DorisDataType.BOOLEAN, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + STARTS_WITH(false, DorisDataType.BOOLEAN, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + TRIM(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + LTRIM(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + RTRIM(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + NULL_OR_EMPTY(false, DorisDataType.BOOLEAN, DorisDataType.VARCHAR), + NOT_NULL_OR_EMPTY(false, DorisDataType.BOOLEAN, DorisDataType.VARCHAR), + HEX(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + UNHEX(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + ELT(true, DorisDataType.VARCHAR, DorisDataType.INT, DorisDataType.VARCHAR), + INSTR(false, DorisDataType.INT, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + LOCATE(false, DorisDataType.INT, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + // FIELD(1, true), + FIND_IN_SET(false, DorisDataType.INT, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + REPLACE(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + LEFT(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR, DorisDataType.INT), + RIGHT(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR, DorisDataType.INT), + STRLEFT(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR, DorisDataType.INT), + STRRIGHT(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR, DorisDataType.INT), + SPLIT_PART(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR, DorisDataType.VARCHAR, DorisDataType.INT), + // SPLIT_BY_STRING(2), + SUBSTRING_INDEX(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR, DorisDataType.VARCHAR, DorisDataType.INT), + MONEY_FORMAT(false, DorisDataType.VARCHAR, DorisDataType.DECIMAL), + PARSE_URL(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + CONVERT_TO(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + EXTRACT_URL_PARAMETER(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + UUID(false, DorisDataType.VARCHAR), SPACE(false, DorisDataType.VARCHAR, DorisDataType.INT), + // SLEEP(1), + ESQUERY(false, DorisDataType.BOOLEAN, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + MASK(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + MASK_FIRST_N(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + MASK_LAST_N(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + // MULTI_SEARCH_ALL_POSITIONS(2), + // MULTI_MATCH_ANY(2), + + // BITMAP functions, https://doris.apache.org/zh-CN/docs/dev/sql-manual/sql-functions/bitmap-functions/to_bitmap + // skip now + + // Bitwise functions, https://doris.apache.org/zh-CN/docs/dev/sql-manual/sql-functions/bitwise-functions/bitand + BITAND(false, DorisDataType.INT, DorisDataType.INT, DorisDataType.INT), + BITOR(false, DorisDataType.INT, DorisDataType.INT, DorisDataType.INT), + BITXOR(false, DorisDataType.INT, DorisDataType.INT, DorisDataType.INT), + BITNOT(false, DorisDataType.INT, DorisDataType.INT), + + // condition funtions + // case(), + COALESCE(true, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + IF(false, DorisDataType.VARCHAR, DorisDataType.BOOLEAN, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + IFNULL(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + NVL(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + NULLIF(false, DorisDataType.VARCHAR, DorisDataType.VARCHAR, DorisDataType.VARCHAR), + + // JSON Functions, https://doris.apache.org/zh-CN/docs/dev/sql-manual/sql-functions/json-functions/jsonb_parse + // skip now + + // Hash functions, + // https://doris.apache.org/zh-CN/docs/dev/sql-manual/sql-functions/hash-functions/murmur_hash3_32 + MURMUR_HASH3_32(true, DorisDataType.INT, DorisDataType.VARCHAR), + MURMUR_HASH3_64(true, DorisDataType.INT, DorisDataType.VARCHAR), + + // HLL functions, https://doris.apache.org/zh-CN/docs/dev/sql-manual/sql-functions/hll-functions/hll_cardinality + // skip now + + // Math functions, https://doris.apache.org/zh-CN/docs/dev/sql-manual/sql-functions/math-functions/conv + CONV(false, DorisDataType.VARCHAR, DorisDataType.INT, DorisDataType.INT, DorisDataType.INT), + BIN(false, DorisDataType.VARCHAR, DorisDataType.INT), SIN(false, DorisDataType.FLOAT, DorisDataType.FLOAT), + COS(false, DorisDataType.FLOAT, DorisDataType.FLOAT), TAN(false, DorisDataType.FLOAT, DorisDataType.FLOAT), + ASIN(false, DorisDataType.FLOAT, DorisDataType.FLOAT), ACOS(false, DorisDataType.FLOAT, DorisDataType.FLOAT), + ATAN(false, DorisDataType.FLOAT, DorisDataType.FLOAT), E(false, DorisDataType.FLOAT), + PI(false, DorisDataType.FLOAT), EXP(false, DorisDataType.FLOAT, DorisDataType.FLOAT), + LOG(false, DorisDataType.FLOAT, DorisDataType.FLOAT, DorisDataType.FLOAT), + LOG2(false, DorisDataType.FLOAT, DorisDataType.FLOAT), LN(false, DorisDataType.FLOAT, DorisDataType.FLOAT), + LOG10(false, DorisDataType.FLOAT, DorisDataType.FLOAT), CEIL(false, DorisDataType.FLOAT, DorisDataType.FLOAT), + FLOOR(false, DorisDataType.FLOAT, DorisDataType.FLOAT), + PMOD(false, DorisDataType.FLOAT, DorisDataType.FLOAT, DorisDataType.FLOAT), + ROUND(false, DorisDataType.INT, DorisDataType.FLOAT), + ROUND_BANKERS(false, DorisDataType.FLOAT, DorisDataType.FLOAT, DorisDataType.INT), + TRUNCATE(false, DorisDataType.FLOAT, DorisDataType.FLOAT, DorisDataType.INT), + ABS(false, DorisDataType.FLOAT, DorisDataType.FLOAT), SQRT(false, DorisDataType.FLOAT, DorisDataType.FLOAT), + CBRT(false, DorisDataType.FLOAT, DorisDataType.FLOAT), + POW(false, DorisDataType.FLOAT, DorisDataType.FLOAT, DorisDataType.FLOAT), + DEGREES(false, DorisDataType.FLOAT, DorisDataType.FLOAT), + RADIANS(false, DorisDataType.FLOAT, DorisDataType.FLOAT), SIGN(false, DorisDataType.INT, DorisDataType.FLOAT), + POSTIVE(false, DorisDataType.FLOAT, DorisDataType.FLOAT), + NEGATIVE(false, DorisDataType.FLOAT, DorisDataType.FLOAT), + GREATEST(true, DorisDataType.FLOAT, DorisDataType.FLOAT), LEAST(true, DorisDataType.FLOAT, DorisDataType.FLOAT), + RANDOM(false, DorisDataType.FLOAT), MOD(false, DorisDataType.FLOAT, DorisDataType.FLOAT, DorisDataType.FLOAT); + + // encrypt-digest-functions, + // https://doris.apache.org/zh-CN/docs/dev/sql-manual/sql-functions/encrypt-digest-functions/aes + // skip now + + private boolean isVariadic; // If isVALid is true, then treat the last argumentTypes as an infinite type + private DorisDataType returnType; + private DorisDataType[] argumentTypes; + private String functionName; + + DorisFunction(String functionName, boolean isVariadic, DorisDataType returnType, + DorisDataType... argumentTypes) { + this.functionName = functionName; + this.isVariadic = isVariadic; + this.returnType = returnType; + this.argumentTypes = argumentTypes.clone(); + } + + DorisFunction(boolean isVariadic, DorisDataType returnType, DorisDataType... argumentTypes) { + this.functionName = toString(); + this.isVariadic = isVariadic; + this.returnType = returnType; + this.argumentTypes = argumentTypes.clone(); + } + + DorisFunction(boolean isVariadic, DorisDataType returnType) { + this.functionName = toString(); + this.isVariadic = isVariadic; + this.returnType = returnType; + this.argumentTypes = null; + } + + public String getFunctionName() { + return functionName; + } + + public static DorisFunction getRandom() { + return Randomly.fromOptions(values()); + } + + public boolean isVariadic() { + return isVariadic; + } + + public boolean isCompatibleWithReturnType(DorisDataType returnType) { + return this.returnType == returnType; + } + + public DorisDataType[] getArgumentTypes() { + if (argumentTypes == null) { + return null; + } + return argumentTypes.clone(); + } + + public DorisFunctionOperation getCall(DorisDataType returnType, DorisNewExpressionGenerator gen, int depth) { + List arguments = new ArrayList<>(); + if (getArgumentTypes() != null) { + Stream.of(getArgumentTypes()).forEach(arg -> arguments.add(gen.generateExpression(arg, depth + 1))); + } + return new DorisFunctionOperation(this, arguments); + } + + public static List getFunctionsCompatibleWith(DorisDataType returnType) { + return Stream.of(values()).filter(f -> f.isCompatibleWithReturnType(returnType)) + .collect(Collectors.toList()); + } + + } + + public DorisFunctionOperation(DorisFunction function, List args) { + this.function = function; + this.args = args; + } + + public List getArgs() { + return args; + } + + public DorisFunction getFunction() { + return function; + } + +} diff --git a/src/sqlancer/doris/ast/DorisInOperation.java b/src/sqlancer/doris/ast/DorisInOperation.java new file mode 100644 index 000000000..96ccd0998 --- /dev/null +++ b/src/sqlancer/doris/ast/DorisInOperation.java @@ -0,0 +1,52 @@ +package sqlancer.doris.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewInOperatorNode; +import sqlancer.doris.DorisSchema; + +public class DorisInOperation extends NewInOperatorNode implements DorisExpression { + + private final DorisExpression leftExpr; + private final List rightExpr; + + public DorisInOperation(DorisExpression left, List right, boolean isNegated) { + super(left, right, isNegated); + this.leftExpr = left; + this.rightExpr = right; + } + + @Override + public DorisSchema.DorisDataType getExpectedType() { + return DorisSchema.DorisDataType.BOOLEAN; + } + + @Override + public DorisConstant getExpectedValue() { + DorisConstant leftValue = leftExpr.getExpectedValue(); + if (leftValue == null) { + return null; + } + if (leftValue.isNull()) { + return DorisConstant.createNullConstant(); + } + boolean containNull = false; + for (DorisExpression expr : rightExpr) { + DorisConstant rightValue = expr.getExpectedValue(); + if (rightValue == null) { + return null; + } + if (rightValue.isNull()) { + containNull = true; + } else if (rightValue.valueEquals(leftValue).isBoolean() && rightValue.valueEquals(leftValue).asBoolean()) { + return DorisConstant.createBooleanConstant(!isNegated()); + } + } + + if (containNull) { + return DorisConstant.createNullConstant(); + } + // should return false when not considering isNegated op + return DorisConstant.createBooleanConstant(isNegated()); + } +} diff --git a/src/sqlancer/doris/ast/DorisJoin.java b/src/sqlancer/doris/ast/DorisJoin.java new file mode 100644 index 000000000..c92555ec3 --- /dev/null +++ b/src/sqlancer/doris/ast/DorisJoin.java @@ -0,0 +1,109 @@ +package sqlancer.doris.ast; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.ast.newast.Join; +import sqlancer.doris.DorisProvider.DorisGlobalState; +import sqlancer.doris.DorisSchema; +import sqlancer.doris.DorisSchema.DorisColumn; +import sqlancer.doris.DorisSchema.DorisTable; +import sqlancer.doris.gen.DorisNewExpressionGenerator; + +public class DorisJoin implements DorisExpression, Join { + + private final DorisTableReference leftTable; + private final DorisTableReference rightTable; + private final JoinType joinType; + private DorisExpression onCondition; + + public enum JoinType { + INNER, STRAIGHT, LEFT, RIGHT; + + public static JoinType getRandom() { + return Randomly.fromOptions(values()); + } + } + + public DorisJoin(DorisTableReference leftTable, DorisTableReference rightTable, JoinType joinType, + DorisExpression whereCondition) { + this.leftTable = leftTable; + this.rightTable = rightTable; + this.joinType = joinType; + this.onCondition = whereCondition; + } + + public DorisTableReference getLeftTable() { + return leftTable; + } + + public DorisTableReference getRightTable() { + return rightTable; + } + + public JoinType getJoinType() { + return joinType; + } + + public DorisExpression getOnCondition() { + return onCondition; + } + + public static List getJoins(List tableList, DorisGlobalState globalState) { + List joinExpressions = new ArrayList<>(); + while (tableList.size() >= 2 && Randomly.getBooleanWithRatherLowProbability()) { + DorisTableReference leftTable = tableList.remove(0); + DorisTableReference rightTable = tableList.remove(0); + List columns = new ArrayList<>(leftTable.getTable().getColumns()); + columns.addAll(rightTable.getTable().getColumns()); + DorisNewExpressionGenerator joinGen = new DorisNewExpressionGenerator(globalState).setColumns(columns); + switch (DorisJoin.JoinType.getRandom()) { + case INNER: + joinExpressions.add(DorisJoin.createInnerJoin(leftTable, rightTable, + joinGen.generateExpression(DorisSchema.DorisDataType.BOOLEAN))); + break; + case STRAIGHT: + joinExpressions.add(DorisJoin.createStraightJoin(leftTable, rightTable, + joinGen.generateExpression(DorisSchema.DorisDataType.BOOLEAN))); + break; + case LEFT: + joinExpressions.add(DorisJoin.createLeftOuterJoin(leftTable, rightTable, + joinGen.generateExpression(DorisSchema.DorisDataType.BOOLEAN))); + break; + case RIGHT: + joinExpressions.add(DorisJoin.createRightOuterJoin(leftTable, rightTable, + joinGen.generateExpression(DorisSchema.DorisDataType.BOOLEAN))); + break; + default: + throw new AssertionError(); + } + } + return joinExpressions; + } + + public static DorisJoin createInnerJoin(DorisTableReference left, DorisTableReference right, + DorisExpression predicate) { + return new DorisJoin(left, right, JoinType.INNER, predicate); + } + + public static DorisJoin createStraightJoin(DorisTableReference left, DorisTableReference right, + DorisExpression predicate) { + return new DorisJoin(left, right, JoinType.STRAIGHT, predicate); + } + + public static DorisJoin createRightOuterJoin(DorisTableReference left, DorisTableReference right, + DorisExpression predicate) { + return new DorisJoin(left, right, JoinType.RIGHT, predicate); + } + + public static DorisJoin createLeftOuterJoin(DorisTableReference left, DorisTableReference right, + DorisExpression predicate) { + return new DorisJoin(left, right, JoinType.LEFT, predicate); + } + + @Override + public void setOnClause(DorisExpression onClause) { + onCondition = onClause; + } +} diff --git a/src/sqlancer/doris/ast/DorisLikeOperation.java b/src/sqlancer/doris/ast/DorisLikeOperation.java new file mode 100644 index 000000000..128a9299d --- /dev/null +++ b/src/sqlancer/doris/ast/DorisLikeOperation.java @@ -0,0 +1,84 @@ +package sqlancer.doris.ast; + +import sqlancer.LikeImplementationHelper; +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.newast.NewBinaryOperatorNode; +import sqlancer.doris.DorisSchema.DorisDataType; + +public class DorisLikeOperation extends NewBinaryOperatorNode implements DorisExpression { + + public DorisLikeOperation(DorisExpression left, DorisExpression right, DorisLikeOperator op) { + super(left, right, op); + } + + @Override + public DorisDataType getExpectedType() { + return DorisDataType.BOOLEAN; + } + + public DorisExpression getLeftExpr() { + return super.getLeft(); + } + + public DorisExpression getRightExpr() { + return super.getRight(); + } + + public DorisLikeOperator getOp() { + return (DorisLikeOperator) op; + } + + @Override + public DorisConstant getExpectedValue() { + DorisConstant leftVal = getLeftExpr().getExpectedValue(); + DorisConstant rightVal = getRightExpr().getExpectedValue(); + if (leftVal == null || rightVal == null) { + return null; + } + return getOp().apply(leftVal, rightVal); + } + + public enum DorisLikeOperator implements BinaryOperatorNode.Operator { + LIKE_OPERATOR("LIKE", "like") { + @Override + public DorisConstant apply(DorisConstant left, DorisConstant right) { + if (left == null || right == null) { + return null; + } + if (left.isNull() || right.isNull()) { + return DorisConstant.createNullConstant(); + } + boolean result = LikeImplementationHelper.match(left.asString(), right.asString(), 0, 0, true); + return DorisConstant.createBooleanConstant(result); + } + }, + NOT_LIKE("NOT LIKE", "not like") { + @Override + public DorisConstant apply(DorisConstant left, DorisConstant right) { + if (left == null || right == null) { + return null; + } + if (left.isNull() || right.isNull()) { + return DorisConstant.createNullConstant(); + } + boolean result = LikeImplementationHelper.match(left.asString(), right.asString(), 0, 0, true); + return DorisConstant.createBooleanConstant(!result); + } + }; + + private final String[] textRepresentations; + + DorisLikeOperator(String... text) { + textRepresentations = text.clone(); + } + + public abstract DorisConstant apply(DorisConstant left, DorisConstant right); + + @Override + public String getTextRepresentation() { + return " " + Randomly.fromOptions(textRepresentations) + " "; + } + } + +} diff --git a/src/sqlancer/doris/ast/DorisOrderByTerm.java b/src/sqlancer/doris/ast/DorisOrderByTerm.java new file mode 100644 index 000000000..01a7ab6d2 --- /dev/null +++ b/src/sqlancer/doris/ast/DorisOrderByTerm.java @@ -0,0 +1,9 @@ +package sqlancer.doris.ast; + +import sqlancer.common.ast.newast.NewOrderingTerm; + +public class DorisOrderByTerm extends NewOrderingTerm implements DorisExpression { + public DorisOrderByTerm(DorisExpression expr, Ordering ordering) { + super(expr, ordering); + } +} diff --git a/src/sqlancer/doris/ast/DorisPostfixText.java b/src/sqlancer/doris/ast/DorisPostfixText.java new file mode 100644 index 000000000..889f4700b --- /dev/null +++ b/src/sqlancer/doris/ast/DorisPostfixText.java @@ -0,0 +1,9 @@ +package sqlancer.doris.ast; + +import sqlancer.common.ast.newast.NewPostfixTextNode; + +public class DorisPostfixText extends NewPostfixTextNode implements DorisExpression { + public DorisPostfixText(DorisExpression expr, String text) { + super(expr, text); + } +} diff --git a/src/sqlancer/doris/ast/DorisSelect.java b/src/sqlancer/doris/ast/DorisSelect.java new file mode 100644 index 000000000..4921e0688 --- /dev/null +++ b/src/sqlancer/doris/ast/DorisSelect.java @@ -0,0 +1,64 @@ +package sqlancer.doris.ast; + +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.common.ast.SelectBase; +import sqlancer.common.ast.newast.Select; +import sqlancer.doris.DorisSchema.DorisColumn; +import sqlancer.doris.DorisSchema.DorisTable; +import sqlancer.doris.visitor.DorisToStringVisitor; + +public class DorisSelect extends SelectBase + implements DorisExpression, Select { + + public enum DorisSelectDistinctType { + + ALL, DISTINCT, DISTINCTROW, NULL; + + public static DorisSelectDistinctType getRandomWithoutNull() { + DorisSelectDistinctType sft; + do { + sft = Randomly.fromOptions(values()); + } while (sft == DorisSelectDistinctType.NULL); + return sft; + } + } + + private DorisSelectDistinctType selectDistinctType = DorisSelectDistinctType.ALL; + + public void setDistinct(boolean isDistinct) { + if (isDistinct) { + this.selectDistinctType = DorisSelectDistinctType.DISTINCT; + } else { + this.selectDistinctType = DorisSelectDistinctType.ALL; + } + } + + public void setDistinct(DorisSelectDistinctType type) { + this.selectDistinctType = type; + } + + public boolean isDistinct() { + return this.selectDistinctType == DorisSelectDistinctType.DISTINCT + || this.selectDistinctType == DorisSelectDistinctType.DISTINCTROW; + } + + @Override + public void setJoinClauses(List joinStatements) { + List expressions = joinStatements.stream().map(e -> (DorisExpression) e) + .collect(Collectors.toList()); + setJoinList(expressions); + } + + @Override + public List getJoinClauses() { + return getJoinList().stream().map(e -> (DorisJoin) e).collect(Collectors.toList()); + } + + @Override + public String asString() { + return DorisToStringVisitor.asString(this); + } +} diff --git a/src/sqlancer/doris/ast/DorisTableReference.java b/src/sqlancer/doris/ast/DorisTableReference.java new file mode 100644 index 000000000..b8a2cc2a2 --- /dev/null +++ b/src/sqlancer/doris/ast/DorisTableReference.java @@ -0,0 +1,11 @@ +package sqlancer.doris.ast; + +import sqlancer.common.ast.newast.TableReferenceNode; +import sqlancer.doris.DorisSchema; + +public class DorisTableReference extends TableReferenceNode + implements DorisExpression { + public DorisTableReference(DorisSchema.DorisTable table) { + super(table); + } +} diff --git a/src/sqlancer/doris/ast/DorisUnaryPostfixOperation.java b/src/sqlancer/doris/ast/DorisUnaryPostfixOperation.java new file mode 100644 index 000000000..f6961c7b7 --- /dev/null +++ b/src/sqlancer/doris/ast/DorisUnaryPostfixOperation.java @@ -0,0 +1,86 @@ +package sqlancer.doris.ast; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.newast.NewUnaryPostfixOperatorNode; +import sqlancer.doris.DorisSchema.DorisDataType; + +public class DorisUnaryPostfixOperation extends NewUnaryPostfixOperatorNode + implements DorisExpression { + + public DorisUnaryPostfixOperation(DorisExpression expr, DorisUnaryPostfixOperator op) { + super(expr, op); + } + + public DorisExpression getExpression() { + return getExpr(); + } + + public DorisUnaryPostfixOperator getOp() { + return (DorisUnaryPostfixOperator) op; + } + + @Override + public DorisDataType getExpectedType() { + return DorisDataType.BOOLEAN; + } + + @Override + public DorisConstant getExpectedValue() { + DorisConstant expectedValue = getExpression().getExpectedValue(); + if (expectedValue == null) { + return null; + } + return getOp().apply(expectedValue); + } + + public enum DorisUnaryPostfixOperator implements BinaryOperatorNode.Operator { + IS_NULL("IS NULL") { + @Override + public DorisDataType[] getInputDataTypes() { + return DorisDataType.values(); + } + + @Override + public DorisConstant apply(DorisConstant value) { + return DorisConstant.createBooleanConstant(value.isNull()); + } + }, + IS_NOT_NULL("IS NOT NULL") { + @Override + public DorisDataType[] getInputDataTypes() { + return DorisDataType.values(); + } + + @Override + public DorisConstant apply(DorisConstant value) { + return DorisConstant.createBooleanConstant(!value.isNull()); + } + }; + + private final String textRepresentations; + + DorisUnaryPostfixOperator(String text) { + this.textRepresentations = text; + } + + public static DorisUnaryPostfixOperator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepresentations; + } + + public abstract DorisDataType[] getInputDataTypes(); + + public abstract DorisConstant apply(DorisConstant value); + } + + @Override + public String getOperatorRepresentation() { + return this.op.getTextRepresentation(); + } + +} diff --git a/src/sqlancer/doris/ast/DorisUnaryPrefixOperation.java b/src/sqlancer/doris/ast/DorisUnaryPrefixOperation.java new file mode 100644 index 000000000..ffd919fdc --- /dev/null +++ b/src/sqlancer/doris/ast/DorisUnaryPrefixOperation.java @@ -0,0 +1,111 @@ +package sqlancer.doris.ast; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode; +import sqlancer.doris.DorisSchema.DorisDataType; + +public class DorisUnaryPrefixOperation extends NewUnaryPrefixOperatorNode implements DorisExpression { + + public DorisUnaryPrefixOperation(DorisExpression expr, DorisUnaryPrefixOperator op) { + super(expr, op); + } + + public DorisExpression getExpression() { + return getExpr(); + } + + public DorisUnaryPrefixOperator getOp() { + return (DorisUnaryPrefixOperator) op; + } + + @Override + public DorisDataType getExpectedType() { + return getOp().getExpressionType(getExpression()); + } + + @Override + public DorisConstant getExpectedValue() { + DorisConstant expectedValue = getExpression().getExpectedValue(); + if (expectedValue == null) { + return null; + } + return getOp().apply(expectedValue); + } + + public enum DorisUnaryPrefixOperator implements BinaryOperatorNode.Operator { + NOT("NOT", DorisDataType.BOOLEAN, DorisDataType.INT) { + @Override + public DorisDataType getExpressionType(DorisExpression expr) { + return DorisDataType.BOOLEAN; + } + + @Override + protected DorisConstant apply(DorisConstant value) { + if (value.isNull()) { + return DorisConstant.createNullConstant(); + } else { + return DorisConstant.createBooleanConstant(!value.cast(DorisDataType.BOOLEAN).asBoolean()); + } + } + }, + + UNARY_PLUS("+", DorisDataType.INT) { + @Override + public DorisDataType getExpressionType(DorisExpression expr) { + return expr.getExpectedType(); + } + + @Override + protected DorisConstant apply(DorisConstant value) { + return value; + } + }, + UNARY_MINUS("-", DorisDataType.INT) { + @Override + public DorisDataType getExpressionType(DorisExpression expr) { + return expr.getExpectedType(); + } + + @Override + protected DorisConstant apply(DorisConstant value) { + if (value.isNull()) { + return DorisConstant.createNullConstant(); + } + try { + if (value.isInt()) { + return DorisConstant.createIntConstant(-value.asInt()); + } + if (value.isFloat()) { + return DorisConstant.createFloatConstant(-value.asFloat()); + } + return null; + } catch (UnsupportedOperationException e) { + return null; + } + } + }; + + private String textRepresentation; + private DorisDataType[] dataTypes; + + DorisUnaryPrefixOperator(String textRepresentation, DorisDataType... dataTypes) { + this.textRepresentation = textRepresentation; + this.dataTypes = dataTypes.clone(); + } + + public abstract DorisDataType getExpressionType(DorisExpression expr); + + public DorisDataType getRandomInputDataTypes() { + return Randomly.fromOptions(dataTypes); + } + + protected abstract DorisConstant apply(DorisConstant value); + + @Override + public String getTextRepresentation() { + return this.textRepresentation; + } + } + +} diff --git a/src/sqlancer/doris/gen/DorisAlterTableGenerator.java b/src/sqlancer/doris/gen/DorisAlterTableGenerator.java new file mode 100644 index 000000000..bfac074fc --- /dev/null +++ b/src/sqlancer/doris/gen/DorisAlterTableGenerator.java @@ -0,0 +1,50 @@ +package sqlancer.doris.gen; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.doris.DorisProvider.DorisGlobalState; +import sqlancer.doris.DorisSchema.DorisCompositeDataType; +import sqlancer.doris.DorisSchema.DorisTable; + +public final class DorisAlterTableGenerator { + + private DorisAlterTableGenerator() { + } + + enum Action { + ADD_COLUMN, ALTER_COLUMN, DROP_COLUMN + } + + public static SQLQueryAdapter getQuery(DorisGlobalState globalState) { + ExpectedErrors errors = new ExpectedErrors(); + StringBuilder sb = new StringBuilder("ALTER TABLE "); + DorisTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + sb.append(table.getName()); + sb.append(" "); + Action action = Randomly.fromOptions(Action.values()); + switch (action) { + case ADD_COLUMN: + sb.append("ADD COLUMN "); + String columnName = table.getFreeColumnName(); + sb.append(columnName); + sb.append(" "); + sb.append(DorisCompositeDataType.getRandomWithoutNull().toString()); + break; + case ALTER_COLUMN: + sb.append("MODIFY COLUMN "); + sb.append(table.getRandomColumn().getName()); + sb.append(" "); + sb.append(DorisCompositeDataType.getRandomWithoutNull().toString()); + break; + case DROP_COLUMN: + sb.append("DROP COLUMN "); + sb.append(table.getRandomColumn().getName()); + break; + default: + throw new AssertionError(action); + } + return new SQLQueryAdapter(sb.toString(), errors, true); + } + +} diff --git a/src/sqlancer/doris/gen/DorisDeleteGenerator.java b/src/sqlancer/doris/gen/DorisDeleteGenerator.java new file mode 100644 index 000000000..27f369aec --- /dev/null +++ b/src/sqlancer/doris/gen/DorisDeleteGenerator.java @@ -0,0 +1,31 @@ +package sqlancer.doris.gen; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.doris.DorisErrors; +import sqlancer.doris.DorisProvider.DorisGlobalState; +import sqlancer.doris.DorisSchema; +import sqlancer.doris.DorisSchema.DorisTable; +import sqlancer.doris.visitor.DorisToStringVisitor; + +public final class DorisDeleteGenerator { + + private DorisDeleteGenerator() { + } + + public static SQLQueryAdapter generate(DorisGlobalState globalState) { + StringBuilder sb = new StringBuilder("DELETE FROM "); + ExpectedErrors errors = new ExpectedErrors(); + DorisTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + sb.append(table.getName()); + if (Randomly.getBoolean()) { + sb.append(" WHERE "); + sb.append(DorisToStringVisitor.asString(new DorisNewExpressionGenerator(globalState) + .setColumns(table.getColumns()).generateExpression(DorisSchema.DorisDataType.BOOLEAN))); + DorisErrors.addExpressionErrors(errors); + } + return new SQLQueryAdapter(sb.toString(), errors); + } + +} diff --git a/src/sqlancer/doris/gen/DorisDropTableGenerator.java b/src/sqlancer/doris/gen/DorisDropTableGenerator.java new file mode 100644 index 000000000..c8bbc67d4 --- /dev/null +++ b/src/sqlancer/doris/gen/DorisDropTableGenerator.java @@ -0,0 +1,28 @@ +package sqlancer.doris.gen; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.doris.DorisProvider.DorisGlobalState; + +public final class DorisDropTableGenerator { + + private DorisDropTableGenerator() { + } + + public static SQLQueryAdapter dropTable(DorisGlobalState globalState) { + if (globalState.getSchema().getTables(t -> !t.isView()).size() <= 1) { + throw new IgnoreMeException(); + } + StringBuilder sb = new StringBuilder("DROP TABLE "); + if (Randomly.getBoolean()) { + sb.append("IF EXISTS "); + } + sb.append(globalState.getSchema().getRandomTableOrBailout(t -> !t.isView()).getName()); + if (Randomly.getBoolean()) { + sb.append(" FORCE "); + } + return new SQLQueryAdapter(sb.toString(), null, true); + } + +} diff --git a/src/sqlancer/doris/gen/DorisDropViewGenerator.java b/src/sqlancer/doris/gen/DorisDropViewGenerator.java new file mode 100644 index 000000000..4f87ba88d --- /dev/null +++ b/src/sqlancer/doris/gen/DorisDropViewGenerator.java @@ -0,0 +1,27 @@ +package sqlancer.doris.gen; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.doris.DorisProvider.DorisGlobalState; + +public final class DorisDropViewGenerator { + + private DorisDropViewGenerator() { + } + + public static SQLQueryAdapter dropView(DorisGlobalState globalState) { + if (globalState.getSchema().getTables(t -> t.isView()).isEmpty()) { + throw new IgnoreMeException(); + } + StringBuilder sb = new StringBuilder("DROP VIEW "); + if (Randomly.getBoolean()) { + sb.append("IF EXISTS "); + } + // TODO: DROP VIEW syntax: DROP MATERIALIZED VIEW [IF EXISTS] mv_name ON table_name; + // should record original table name in view table + sb.append(globalState.getSchema().getRandomTableOrBailout(t -> t.isView()).getName()); + return new SQLQueryAdapter(sb.toString(), null, true); + } + +} diff --git a/src/sqlancer/doris/gen/DorisIndexGenerator.java b/src/sqlancer/doris/gen/DorisIndexGenerator.java new file mode 100644 index 000000000..5e56bb192 --- /dev/null +++ b/src/sqlancer/doris/gen/DorisIndexGenerator.java @@ -0,0 +1,46 @@ +package sqlancer.doris.gen; + +import java.sql.SQLException; +import java.util.List; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.doris.DorisProvider.DorisGlobalState; +import sqlancer.doris.DorisSchema.DorisColumn; +import sqlancer.doris.DorisSchema.DorisTable; + +public final class DorisIndexGenerator { + + private DorisIndexGenerator() { + } + + public static SQLQueryAdapter getQuery(DorisGlobalState globalState) throws SQLException { + if (globalState.getSchema().getIndexCount() > globalState.getDbmsSpecificOptions().maxNumIndexes) { + throw new IgnoreMeException(); + } + ExpectedErrors errors = new ExpectedErrors(); + + DorisTable randomTable = globalState.getSchema().getRandomTable(t -> !t.isView()); + String indexName = globalState.getSchema().getFreeIndexName(); + StringBuilder sb = new StringBuilder("CREATE "); + sb.append("INDEX "); + if (Randomly.getBoolean()) { + sb.append("IF NOT EXISTS "); + } + sb.append(indexName); + sb.append(" ON "); + sb.append(randomTable.getName()); + sb.append("("); + int nr = 1; // Doris Only support CREATE_INDEX on single column and index type is BITMAP; + List subset = Randomly.extractNrRandomColumns(randomTable.getColumns(), nr); + sb.append(subset.get(0).getName()); + sb.append(") "); + if (Randomly.getBoolean()) { + sb.append("USING BITMAP "); + } + return new SQLQueryAdapter(sb.toString(), errors, true); + } + +} diff --git a/src/sqlancer/doris/gen/DorisInsertGenerator.java b/src/sqlancer/doris/gen/DorisInsertGenerator.java new file mode 100644 index 000000000..50dc5cdec --- /dev/null +++ b/src/sqlancer/doris/gen/DorisInsertGenerator.java @@ -0,0 +1,54 @@ +package sqlancer.doris.gen; + +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.common.gen.AbstractInsertGenerator; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.doris.DorisErrors; +import sqlancer.doris.DorisProvider.DorisGlobalState; +import sqlancer.doris.DorisSchema.DorisColumn; +import sqlancer.doris.DorisSchema.DorisTable; +import sqlancer.doris.visitor.DorisToStringVisitor; + +public class DorisInsertGenerator extends AbstractInsertGenerator { + + private final DorisGlobalState globalState; + private final ExpectedErrors errors = new ExpectedErrors(); + + public DorisInsertGenerator(DorisGlobalState globalState) { + this.globalState = globalState; + } + + public static SQLQueryAdapter getQuery(DorisGlobalState globalState) { + return new DorisInsertGenerator(globalState).generate(); + } + + private SQLQueryAdapter generate() { + sb.append("INSERT INTO "); + DorisTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + List columns = table.getRandomNonEmptyInsertColumns(); + sb.append(table.getName()); + sb.append(" ("); + sb.append(columns.stream().map(c -> c.getName()).collect(Collectors.joining(", "))); + sb.append(")"); + sb.append(" VALUES "); + insertColumns(columns); + DorisErrors.addInsertErrors(errors); + return new SQLQueryAdapter(sb.toString(), errors); + } + + @Override + protected void insertValue(DorisColumn column) { + if (column.hasDefaultValue() && Randomly.getBooleanWithRatherLowProbability()) { + sb.append("DEFAULT"); + } else { + String value = DorisToStringVisitor.asString(new DorisNewExpressionGenerator(globalState) + .generateConstant(column.getType().getPrimitiveDataType(), column.isNullable())); // 生成一个与column相同的常量类型 + sb.append(value); + } + } + +} diff --git a/src/sqlancer/doris/gen/DorisNewExpressionGenerator.java b/src/sqlancer/doris/gen/DorisNewExpressionGenerator.java new file mode 100644 index 000000000..bddc6cc5a --- /dev/null +++ b/src/sqlancer/doris/gen/DorisNewExpressionGenerator.java @@ -0,0 +1,545 @@ +package sqlancer.doris.gen; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.common.ast.newast.NewOrderingTerm; +import sqlancer.common.gen.NoRECGenerator; +import sqlancer.common.gen.TLPWhereGenerator; +import sqlancer.common.gen.TypedExpressionGenerator; +import sqlancer.common.schema.AbstractTables; +import sqlancer.doris.DorisBugs; +import sqlancer.doris.DorisProvider.DorisGlobalState; +import sqlancer.doris.DorisSchema.DorisColumn; +import sqlancer.doris.DorisSchema.DorisCompositeDataType; +import sqlancer.doris.DorisSchema.DorisDataType; +import sqlancer.doris.DorisSchema.DorisRowValue; +import sqlancer.doris.DorisSchema.DorisTable; +import sqlancer.doris.ast.DorisAggregateOperation; +import sqlancer.doris.ast.DorisAggregateOperation.DorisAggregateFunction; +import sqlancer.doris.ast.DorisBetweenOperation; +import sqlancer.doris.ast.DorisBinaryArithmeticOperation; +import sqlancer.doris.ast.DorisBinaryArithmeticOperation.DorisBinaryArithmeticOperator; +import sqlancer.doris.ast.DorisBinaryComparisonOperation; +import sqlancer.doris.ast.DorisBinaryComparisonOperation.DorisBinaryComparisonOperator; +import sqlancer.doris.ast.DorisBinaryLogicalOperation; +import sqlancer.doris.ast.DorisBinaryLogicalOperation.DorisBinaryLogicalOperator; +import sqlancer.doris.ast.DorisCaseOperation; +import sqlancer.doris.ast.DorisCastOperation; +import sqlancer.doris.ast.DorisColumnReference; +import sqlancer.doris.ast.DorisColumnValue; +import sqlancer.doris.ast.DorisConstant; +import sqlancer.doris.ast.DorisExpression; +import sqlancer.doris.ast.DorisFunctionOperation.DorisFunction; +import sqlancer.doris.ast.DorisInOperation; +import sqlancer.doris.ast.DorisJoin; +import sqlancer.doris.ast.DorisLikeOperation; +import sqlancer.doris.ast.DorisOrderByTerm; +import sqlancer.doris.ast.DorisPostfixText; +import sqlancer.doris.ast.DorisSelect; +import sqlancer.doris.ast.DorisTableReference; +import sqlancer.doris.ast.DorisUnaryPostfixOperation; +import sqlancer.doris.ast.DorisUnaryPostfixOperation.DorisUnaryPostfixOperator; +import sqlancer.doris.ast.DorisUnaryPrefixOperation; +import sqlancer.doris.ast.DorisUnaryPrefixOperation.DorisUnaryPrefixOperator; +import sqlancer.doris.visitor.DorisToStringVisitor; + +public class DorisNewExpressionGenerator extends TypedExpressionGenerator + implements NoRECGenerator, + TLPWhereGenerator { + + private final DorisGlobalState globalState; + private List tables; + + private final int maxDepth; + private boolean allowAggregateFunctions; + private DorisRowValue rowValue; + + private Set columnOfLeafNode; + + public DorisNewExpressionGenerator setRowValue(DorisRowValue rowValue) { + this.rowValue = rowValue; + return this; + } + + public void setColumnOfLeafNode(Set columnOfLeafNode) { + this.columnOfLeafNode = columnOfLeafNode; + } + + public DorisNewExpressionGenerator(DorisGlobalState globalState) { + this.globalState = globalState; + this.maxDepth = globalState.getOptions().getMaxExpressionDepth(); + } + + @Override + public DorisExpression generateLeafNode(DorisDataType dataType) { + if (Randomly.getBoolean()) { + return generateConstant(dataType); + } else { + if (filterColumns(dataType).isEmpty()) { + return generateConstant(dataType); + } else { + return createColumnOfType(dataType); + } + } + } + + final List filterColumns(DorisDataType dataType) { + if (columns == null) { + return Collections.emptyList(); + } else { + return columns.stream().filter(c -> c.getType().getPrimitiveDataType() == dataType) + .collect(Collectors.toList()); + } + } + + private DorisExpression createColumnOfType(DorisDataType type) { + List columns = filterColumns(type); + DorisColumn column = Randomly.fromList(columns); + DorisConstant value = rowValue == null ? null : rowValue.getValues().get(column); + if (columnOfLeafNode != null) { + columnOfLeafNode.add(DorisColumnValue.create(column, value)); + } + return DorisColumnValue.create(column, value); + } + + public List generateOrderBy() { + List randomColumns = Randomly.subset(columns); + return randomColumns.stream() + .map(c -> new DorisOrderByTerm(new DorisColumnValue(c, null), NewOrderingTerm.Ordering.getRandom())) + .collect(Collectors.toList()); + } + + @Override + public DorisExpression generateExpression(DorisDataType type, int depth) { + // todo: case operation should be add into generateExpression + + if (Randomly.getBooleanWithRatherLowProbability() || depth >= maxDepth) { + return generateLeafNode(type); + } + + if (globalState.getDbmsSpecificOptions().testFunctions && Randomly.getBooleanWithRatherLowProbability()) { + List applicableFunctions = DorisFunction.getFunctionsCompatibleWith(type); + if (!applicableFunctions.isEmpty()) { + DorisFunction function = Randomly.fromList(applicableFunctions); + return function.getCall(type, this, depth + 1); + } + } + if (!DorisBugs.bug36070 && type != DorisDataType.NULL && globalState.getDbmsSpecificOptions().testCasts + && Randomly.getBooleanWithRatherLowProbability()) { + return new DorisCastOperation(generateExpression(getRandomType(), depth + 1), type); + } + if (!DorisBugs.bug36070 && globalState.getDbmsSpecificOptions().testCase + && Randomly.getBooleanWithRatherLowProbability()) { + DorisExpression expr = generateExpression(DorisDataType.BOOLEAN, depth + 1); + List conditions = new ArrayList<>(); + List cases = new ArrayList<>(); + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + conditions.add(generateExpression(DorisDataType.BOOLEAN, depth + 1)); + cases.add(generateExpression(type, depth + 1)); + } + DorisExpression elseExpr = null; + if (Randomly.getBoolean()) { + elseExpr = generateExpression(type, depth + 1); + } + return new DorisCaseOperation(expr, conditions, cases, elseExpr); + } + + switch (type) { + case INT: + return generateIntExpression(depth); + case BOOLEAN: + return generateBooleanExpression(depth); + case FLOAT: + case DECIMAL: + case DATE: + case DATETIME: + case VARCHAR: + case NULL: + return generateConstant(type); + default: + throw new AssertionError(); + } + } + + public List generateExpressions(int nr, DorisDataType type) { + List expressions = new ArrayList<>(); + for (int i = 0; i < nr; i++) { + expressions.add(generateExpression(type)); + } + return expressions; + } + + private enum IntExpression { + UNARY_OPERATION, BINARY_ARITHMETIC_OPERATION + } + + private DorisExpression generateIntExpression(int depth) { + if (allowAggregateFunctions) { + allowAggregateFunctions = false; + } + IntExpression intExpression = Randomly.fromOptions(IntExpression.values()); + switch (intExpression) { + case UNARY_OPERATION: + return new DorisUnaryPrefixOperation(generateExpression(DorisDataType.INT, depth + 1), + Randomly.getBoolean() ? DorisUnaryPrefixOperator.UNARY_PLUS : DorisUnaryPrefixOperator.UNARY_MINUS); + case BINARY_ARITHMETIC_OPERATION: + return new DorisBinaryArithmeticOperation(generateExpression(DorisDataType.INT, depth + 1), + generateExpression(DorisDataType.INT, depth + 1), + Randomly.fromOptions(DorisBinaryArithmeticOperator.values())); + default: + throw new AssertionError(); + } + } + + private enum BooleanExpression { + POSTFIX_OPERATOR, NOT, BINARY_LOGICAL_OPERATOR, BINARY_COMPARISON, LIKE, BETWEEN, IN_OPERATION; + // SIMILAR_TO, POSIX_REGEX, BINARY_RANGE_COMPARISON,FUNCTION, CAST,; + } + + DorisExpression generateBooleanExpression(int depth) { + if (allowAggregateFunctions) { + allowAggregateFunctions = false; + } + List validOptions = new ArrayList<>(Arrays.asList(BooleanExpression.values())); + if (DorisBugs.bug36346 || !globalState.getDbmsSpecificOptions().testIn) { + validOptions.remove(BooleanExpression.IN_OPERATION); + } + if (!globalState.getDbmsSpecificOptions().testBinaryLogicals) { + validOptions.remove(BooleanExpression.BINARY_LOGICAL_OPERATOR); + } + if (!globalState.getDbmsSpecificOptions().testBinaryComparisons) { + validOptions.remove(BooleanExpression.BINARY_COMPARISON); + } + if (DorisBugs.bug36070 || !globalState.getDbmsSpecificOptions().testBetween) { + validOptions.remove(BooleanExpression.BETWEEN); + } + + BooleanExpression option = Randomly.fromList(validOptions); + switch (option) { + case POSTFIX_OPERATOR: + return getPostfix(depth + 1); + case NOT: + return getNOT(depth + 1); + case BETWEEN: + return getBetween(depth + 1); + case IN_OPERATION: + return getIn(depth + 1); + case BINARY_LOGICAL_OPERATOR: + return getBinaryLogical(depth + 1, DorisDataType.BOOLEAN); + case BINARY_COMPARISON: + return getComparison(depth + 1); + case LIKE: + return getLike(depth + 1, DorisDataType.VARCHAR); + default: + throw new AssertionError(); + } + + } + + DorisExpression getPostfix(int depth) { + DorisUnaryPostfixOperator randomOp = DorisUnaryPostfixOperator.getRandom(); + return new DorisUnaryPostfixOperation( + generateExpression(Randomly.fromOptions(randomOp.getInputDataTypes()), depth), randomOp); + } + + DorisExpression getNOT(int depth) { + DorisUnaryPrefixOperator op = DorisUnaryPrefixOperator.NOT; + return new DorisUnaryPrefixOperation(generateExpression(op.getRandomInputDataTypes(), depth), op); + } + + DorisExpression getBetween(int depth) { + DorisDataType dataType = Randomly.fromList(Arrays.asList(DorisDataType.values()).stream() + .filter(t -> t != DorisDataType.BOOLEAN).collect(Collectors.toList())); + + return new DorisBetweenOperation(generateExpression(dataType, depth), generateExpression(dataType, depth), + generateExpression(dataType, depth), Randomly.getBoolean()); + } + + DorisExpression getIn(int depth) { + DorisDataType dataType = Randomly.fromOptions(DorisDataType.values()); + DorisExpression leftExpr = generateExpression(dataType, depth); + List rightExprs = new ArrayList<>(); + int nr = Randomly.smallNumber() + 1; + for (int i = 0; i < nr; i++) { + rightExprs.add(generateExpression(dataType, depth)); + } + return new DorisInOperation(leftExpr, rightExprs, Randomly.getBoolean()); + } + + DorisExpression getBinaryLogical(int depth, DorisDataType dataType) { + DorisExpression expr = generateExpression(dataType, depth); + int nr = Randomly.smallNumber() + 1; + for (int i = 0; i < nr; i++) { + expr = new DorisBinaryLogicalOperation(expr, generateExpression(DorisDataType.BOOLEAN, depth), + DorisBinaryLogicalOperator.getRandom()); + } + return expr; + } + + DorisExpression getComparison(int depth) { + // 跳过boolean + DorisDataType dataType = Randomly.fromList(Arrays.asList(DorisDataType.values()).stream() + .filter(t -> t != DorisDataType.BOOLEAN).collect(Collectors.toList())); + DorisExpression leftExpr = generateExpression(dataType, depth); + DorisExpression rightExpr = generateExpression(dataType, depth); + return new DorisBinaryComparisonOperation(leftExpr, rightExpr, + Randomly.fromOptions(DorisBinaryComparisonOperator.values())); + } + + DorisExpression getLike(int depth, DorisDataType dataType) { + return new DorisLikeOperation(generateExpression(dataType, depth), generateExpression(dataType, depth), + DorisLikeOperation.DorisLikeOperator.LIKE_OPERATOR); + } + + public DorisExpression generateExpressionWithExpectedResult(DorisDataType type) { + DorisExpression expr; + do { + expr = this.generateExpression(type); + } while (expr.getExpectedValue() == null); + return expr; + } + + @Override + public DorisExpression generatePredicate() { + return generateExpression(DorisDataType.BOOLEAN); + } + + @Override + public DorisExpression negatePredicate(DorisExpression predicate) { + return new DorisUnaryPrefixOperation(predicate, DorisUnaryPrefixOperator.NOT); + } + + @Override + public DorisExpression isNull(DorisExpression predicate) { + return new DorisUnaryPostfixOperation(predicate, DorisUnaryPostfixOperator.IS_NULL); + } + + public DorisExpression generateConstant(DorisDataType type, boolean isNullable) { + if (!isNullable) { + return createConstantWithoutNull(type); + } + if (Randomly.getBooleanWithSmallProbability()) { + return createConstant(DorisDataType.NULL); + } + return createConstant(type); + } + + @Override + public DorisExpression generateConstant(DorisDataType type) { + if (Randomly.getBooleanWithSmallProbability()) { + return DorisConstant.createNullConstant(); + } + return createConstant(type); + } + + public DorisExpression createConstantWithoutNull(DorisDataType type) { + DorisExpression constant = createConstant(type); + int loopCount = 0; + while (constant instanceof DorisConstant.DorisNullConstant && loopCount < 1000) { + constant = createConstant(type); + loopCount++; + } + if (constant instanceof DorisConstant.DorisNullConstant) { + throw new IgnoreMeException(); + } + return constant; + } + + public DorisExpression createConstant(DorisDataType type) { + Randomly r = globalState.getRandomly(); + long timestamp; + switch (type) { + case INT: + if (globalState.getDbmsSpecificOptions().testIntConstants) { + long number = r.getInteger(); + if (DorisBugs.bug36351 && number == -1049190528) { + number = 0; + } + return DorisConstant.createIntConstant(r.getInteger()); + } + return DorisConstant.createNullConstant(); + case BOOLEAN: + if (globalState.getDbmsSpecificOptions().testBooleanConstants) { + return DorisConstant.createBooleanConstant(Randomly.getBoolean()); + } + return DorisConstant.createNullConstant(); + case DECIMAL: + if (globalState.getDbmsSpecificOptions().testDecimalConstants) { + double v = r.getDouble(); + while (v == Double.MAX_VALUE || v == -Double.MAX_VALUE || v == Double.POSITIVE_INFINITY + || v == Double.NEGATIVE_INFINITY) { + v = r.getDouble(); + } + + // e.g. format 1234.413232532 to qualify num 34.4132 + String formatter = "%." + type.getDecimalScale() + "f"; + String vStr = String.format(formatter, v); + int pointPos = vStr.indexOf('.'); + if (pointPos > type.getDecimalPrecision() - type.getDecimalScale()) { + vStr = vStr.substring(pointPos - (type.getDecimalPrecision() - type.getDecimalScale())); + } + return DorisConstant.createFloatConstant(Double.parseDouble(vStr)); + } + return DorisConstant.createNullConstant(); + case FLOAT: + if (globalState.getDbmsSpecificOptions().testFloatConstants) { + return DorisConstant.createFloatConstant((float) r.getDouble()); + } + return DorisConstant.createNullConstant(); + case DATE: + if (globalState.getDbmsSpecificOptions().testDateConstants) { + // [1970-01-01 08:00:00, 3000-01-01 00:00:00] + timestamp = globalState.getRandomly().getLong(0, 32503651200L); + return DorisConstant.createDateConstant(timestamp); + } + return DorisConstant.createNullConstant(); + case DATETIME: + if (globalState.getDbmsSpecificOptions().testDateTimeConstants) { + // [1970-01-01 08:00:00, 3000-01-01 00:00:00] + timestamp = globalState.getRandomly().getLong(0, 32503651200L); + if (DorisBugs.bug36342) { + return DorisConstant.createDatetimeConstant(timestamp); + } + return Randomly.fromOptions(DorisConstant.createDatetimeConstant(timestamp), + DorisConstant.createDatetimeConstant()); + } + return DorisConstant.createNullConstant(); + case VARCHAR: + if (globalState.getDbmsSpecificOptions().testStringConstants) { + String s = r.getString(); + if (s.length() > type.getVarcharLength()) { + s = s.substring(0, type.getVarcharLength()); + } + return DorisConstant.createStringConstant(s); + } + return DorisConstant.createNullConstant(); + case NULL: + return DorisConstant.createNullConstant(); + default: + throw new AssertionError(type); + } + } + + @Override + protected DorisExpression generateColumn(DorisDataType type) { + return null; + } + + @Override + protected DorisDataType getRandomType() { + return Randomly.fromOptions(DorisDataType.values()); + } + + @Override + protected boolean canGenerateColumnOfType(DorisDataType type) { + return false; + } + + public DorisExpression generateArgsForAggregate(DorisAggregateFunction aggregateFunction) { + DorisDataType dataType = Randomly.fromOptions(DorisDataType.values()); + return new DorisAggregateOperation(generateExpressions(aggregateFunction.getNrArgs(), dataType), + aggregateFunction); + } + + public DorisExpression generateAggregate() { + DorisAggregateFunction aggrFunc = DorisAggregateFunction.getRandom(); + return generateArgsForAggregate(aggrFunc); + } + + public DorisExpression generateHavingClause() { + allowAggregateFunctions = true; + DorisExpression expression = generateExpression(DorisDataType.BOOLEAN); + allowAggregateFunctions = false; + return expression; + } + + public void setAllowAggregateFunctions(boolean allowAggregateFunctions) { + this.allowAggregateFunctions = allowAggregateFunctions; + } + + @Override + public DorisNewExpressionGenerator setTablesAndColumns(AbstractTables tables) { + this.columns = tables.getColumns(); + this.tables = tables.getTables(); + + return this; + } + + @Override + public DorisExpression generateBooleanExpression() { + return generateExpression(DorisDataType.BOOLEAN); + } + + @Override + public DorisSelect generateSelect() { + return new DorisSelect(); + } + + @Override + public List getRandomJoinClauses() { + List tableList = tables.stream().map(t -> new DorisTableReference(t)) + .collect(Collectors.toList()); + List joins = DorisJoin.getJoins(tableList, globalState); + tables = tableList.stream().map(t -> t.getTable()).collect(Collectors.toList()); + return joins; + } + + @Override + public List getTableRefs() { + return tables.stream().map(t -> new DorisTableReference(t)).collect(Collectors.toList()); + } + + @Override + public String generateOptimizedQueryString(DorisSelect select, DorisExpression whereCondition, + boolean shouldUseAggregate) { + if (shouldUseAggregate) { + DorisExpression aggr = new DorisAggregateOperation( + List.of(new DorisColumnReference( + new DorisColumn("*", new DorisCompositeDataType(DorisDataType.INT, 0), false, false))), + DorisAggregateFunction.COUNT); + select.setFetchColumns(List.of(aggr)); + + } else { + List allColumns = columns.stream().map((c) -> new DorisColumnReference(c)) + .collect(Collectors.toList()); + select.setFetchColumns(allColumns); + if (Randomly.getBooleanWithSmallProbability()) { + List constants = new ArrayList<>(); + constants.add(new DorisConstant.DorisIntConstant( + Randomly.smallNumber() % select.getFetchColumns().size() + 1)); + select.setOrderByClauses(constants); + } + } + select.setWhereClause(whereCondition); + + return select.asString(); + } + + @Override + public String generateUnoptimizedQueryString(DorisSelect select, DorisExpression whereCondition) { + DorisExpression asText = new DorisPostfixText(new DorisCastOperation( + new DorisPostfixText(whereCondition, + " IS NOT NULL AND " + DorisToStringVisitor.asString(whereCondition)), + new DorisCompositeDataType(DorisDataType.INT, 8)), "as count"); + select.setFetchColumns(Arrays.asList(asText)); + select.setWhereClause(null); + + return "SELECT SUM(count) FROM (" + select.asString() + ") as res"; + } + + @Override + public List generateFetchColumns(boolean shouldCreateDummy) { + if (shouldCreateDummy) { + return List.of(new DorisColumnReference(new DorisColumn("*", null, false, false))); + } + return Randomly.nonEmptySubset(columns).stream().map(c -> new DorisColumnReference(c)) + .collect(Collectors.toList()); + } +} diff --git a/src/sqlancer/doris/gen/DorisRandomQuerySynthesizer.java b/src/sqlancer/doris/gen/DorisRandomQuerySynthesizer.java new file mode 100644 index 000000000..e1ec50eb2 --- /dev/null +++ b/src/sqlancer/doris/gen/DorisRandomQuerySynthesizer.java @@ -0,0 +1,77 @@ +package sqlancer.doris.gen; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.doris.DorisProvider.DorisGlobalState; +import sqlancer.doris.DorisSchema; +import sqlancer.doris.DorisSchema.DorisTable; +import sqlancer.doris.DorisSchema.DorisTables; +import sqlancer.doris.ast.DorisColumnValue; +import sqlancer.doris.ast.DorisConstant; +import sqlancer.doris.ast.DorisExpression; +import sqlancer.doris.ast.DorisJoin; +import sqlancer.doris.ast.DorisSelect; +import sqlancer.doris.ast.DorisTableReference; + +public final class DorisRandomQuerySynthesizer { + + private DorisRandomQuerySynthesizer() { + } + + public static DorisSelect generateSelect(DorisGlobalState globalState, int nrColumns) { + DorisTables targetTables = globalState.getSchema().getRandomTableNonEmptyTables(); + List targetColumns = targetTables.getColumns(); + DorisNewExpressionGenerator gen = new DorisNewExpressionGenerator(globalState).setColumns(targetColumns); + DorisSelect select = new DorisSelect(); + HashSet columnOfLeafNode = new HashSet<>(); + gen.setColumnOfLeafNode(columnOfLeafNode); + int freeColumns = targetColumns.size(); + select.setDistinct(DorisSelect.DorisSelectDistinctType.getRandomWithoutNull()); + List columns = new ArrayList<>(); + for (int i = 0; i < nrColumns; i++) { + DorisExpression column = null; + if (freeColumns > 0 && Randomly.getBoolean()) { + column = new DorisColumnValue(targetColumns.get(freeColumns - 1), null); + freeColumns -= 1; + columnOfLeafNode.add((DorisColumnValue) column); + } else { + column = gen.generateExpression(DorisSchema.DorisDataType.BOOLEAN); + } + columns.add(column); + } + select.setFetchColumns(columns); + List tables = targetTables.getTables(); + List tableList = tables.stream().map(t -> new DorisTableReference(t)) + .collect(Collectors.toList()); + List joins = DorisJoin.getJoins(tableList, globalState); + select.setJoinList(joins.stream().collect(Collectors.toList())); + select.setFromList(tableList.stream().collect(Collectors.toList())); + if (Randomly.getBoolean()) { + select.setHavingClause(gen.generateHavingClause()); + } + if (Randomly.getBoolean()) { + select.setWhereClause(gen.generateExpression(DorisSchema.DorisDataType.BOOLEAN)); + } + + List noExprColumns = new ArrayList<>(columnOfLeafNode); + + if (Randomly.getBoolean()) { + select.setOrderByClauses(Randomly.nonEmptySubset(noExprColumns)); + } + if (Randomly.getBoolean()) { + select.setGroupByExpressions(noExprColumns); + } + if (Randomly.getBoolean()) { + select.setLimitClause(DorisConstant.createIntConstant(Randomly.getNotCachedInteger(0, Integer.MAX_VALUE))); + } + if (Randomly.getBoolean()) { + select.setOffsetClause(DorisConstant.createIntConstant(Randomly.getNotCachedInteger(0, Integer.MAX_VALUE))); + } + return select; + } + +} diff --git a/src/sqlancer/doris/gen/DorisTableGenerator.java b/src/sqlancer/doris/gen/DorisTableGenerator.java new file mode 100644 index 000000000..74d2f956a --- /dev/null +++ b/src/sqlancer/doris/gen/DorisTableGenerator.java @@ -0,0 +1,112 @@ +package sqlancer.doris.gen; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.doris.DorisErrors; +import sqlancer.doris.DorisProvider.DorisGlobalState; +import sqlancer.doris.DorisSchema; +import sqlancer.doris.DorisSchema.DorisColumn; +import sqlancer.doris.DorisSchema.DorisCompositeDataType; +import sqlancer.doris.visitor.DorisToStringVisitor; + +public class DorisTableGenerator { + + // private final ExpectedErrors errors = new ExpectedErrors(); + + public static SQLQueryAdapter createRandomTableStatement(DorisGlobalState globalState) throws SQLException { + if (globalState.getSchema().getDatabaseTables().size() > globalState.getDbmsSpecificOptions().maxNumTables) { + throw new IgnoreMeException(); + } + return new DorisTableGenerator().getQuery(globalState); + } + + public SQLQueryAdapter getQuery(DorisGlobalState globalState) { + ExpectedErrors errors = new ExpectedErrors(); + StringBuilder sb = new StringBuilder(); + String tableName = globalState.getSchema().getFreeTableName(); + DorisSchema.DorisTableDataModel dataModel = DorisSchema.DorisTableDataModel.getRandom(); + sb.append("CREATE TABLE "); + sb.append(tableName); + sb.append("("); + List columns = getNewColumns(globalState); + Collections.sort(columns); + if (columns.isEmpty() || !columns.get(0).isKey()) { + return null; // ensure table has at least one key column + } + sb.append(columns.stream().map(DorisColumn::toString).collect(Collectors.joining(", "))); + sb.append(")"); + + List keysColumn = columns.stream().filter(DorisColumn::isKey).collect(Collectors.toList()); + if (globalState.getDbmsSpecificOptions().testDataModel && Randomly.getBoolean() && !keysColumn.isEmpty()) { + sb.append(" " + dataModel).append(" KEY("); + sb.append(keysColumn.stream().map(c -> c.getName()).collect(Collectors.joining(", "))); + sb.append(")"); + } + sb.append(generateDistributionStr(globalState, dataModel, keysColumn)); + sb.append(" PROPERTIES (\"replication_num\" = \"1\")"); // now only consider this one parameter + DorisErrors.addExpressionErrors(errors); + return new SQLQueryAdapter(sb.toString(), errors, true); + } + + public static String generateDistributionStr(DorisGlobalState globalState, + DorisSchema.DorisTableDataModel dataModel, List keysColumn) { + // DISTRIBUTED BY HASH (k1[,k2 ...]) [BUCKETS num] + // DISTRIBUTED BY RANDOM [BUCKETS num] + StringBuilder sb = new StringBuilder(); + sb.append(" DISTRIBUTED BY"); + if (dataModel == DorisSchema.DorisTableDataModel.UNIQUE || Randomly.getBoolean()) { + sb.append(" HASH ("); + sb.append(Randomly.nonEmptySubset(keysColumn).stream().map(DorisColumn::getName) + .collect(Collectors.joining(", "))); + sb.append(")"); + } else { + sb.append(" RANDOM"); + } + if (Randomly.getBoolean()) { + sb.append(" BUCKETS ").append(globalState.getRandomly().getInteger(1, 32)); + } + return sb.toString(); + } + + private static List getNewColumns(DorisGlobalState globalState) { + List columns = new ArrayList<>(); + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + String columnName = String.format("c%d", i); + DorisCompositeDataType columnType = DorisCompositeDataType.getRandomWithoutNull(); + columnType.initColumnArgs(); // set decimalAndVarchar + + boolean iskey = columnType.canBeKey() && Randomly.getBoolean(); + boolean isNullable = Randomly.getBoolean(); + if (!globalState.getDbmsSpecificOptions().testNotNullConstraints) { + isNullable = true; + } + // boolean isHllOrBitmap = (columnType.getPrimitiveDataType() == DorisSchema.DorisDataType.HLL) + // || (columnType.getPrimitiveDataType() == DorisSchema.DorisDataType.BITMAP); + boolean isHllOrBitmap = false; + DorisSchema.DorisColumnAggrType aggrType = DorisSchema.DorisColumnAggrType.NULL; + if (globalState.getDbmsSpecificOptions().testColumnAggr && (isHllOrBitmap || !iskey)) { + aggrType = DorisSchema.DorisColumnAggrType.getRandom(columnType); + } + + boolean hasDefaultValue = globalState.getDbmsSpecificOptions().testDefaultValues && Randomly.getBoolean() + && !isHllOrBitmap; + String defaultValue = ""; + if (hasDefaultValue) { + defaultValue = DorisToStringVisitor.asString(new DorisNewExpressionGenerator(globalState) + .generateConstant(columnType.getPrimitiveDataType(), isNullable)); + } + columns.add(new DorisColumn(columnName, columnType, iskey, isNullable, aggrType, hasDefaultValue, + defaultValue)); + } + return columns; + } + +} diff --git a/src/sqlancer/doris/gen/DorisUpdateGenerator.java b/src/sqlancer/doris/gen/DorisUpdateGenerator.java new file mode 100644 index 000000000..906173921 --- /dev/null +++ b/src/sqlancer/doris/gen/DorisUpdateGenerator.java @@ -0,0 +1,55 @@ +package sqlancer.doris.gen; + +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.gen.AbstractUpdateGenerator; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.doris.DorisErrors; +import sqlancer.doris.DorisProvider.DorisGlobalState; +import sqlancer.doris.DorisSchema; +import sqlancer.doris.DorisSchema.DorisColumn; +import sqlancer.doris.DorisSchema.DorisTable; +import sqlancer.doris.ast.DorisExpression; +import sqlancer.doris.visitor.DorisToStringVisitor; + +public final class DorisUpdateGenerator extends AbstractUpdateGenerator { + + private final DorisGlobalState globalState; + private DorisNewExpressionGenerator gen; + + private DorisUpdateGenerator(DorisGlobalState globalState) { + this.globalState = globalState; + } + + public static SQLQueryAdapter getQuery(DorisGlobalState globalState) { + return new DorisUpdateGenerator(globalState).generate(); + } + + private SQLQueryAdapter generate() { + DorisTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + List columns = table.getRandomNonEmptyColumnSubset(); + gen = new DorisNewExpressionGenerator(globalState).setColumns(table.getColumns()); + sb.append("UPDATE "); + sb.append(table.getName()); + sb.append(" SET "); + updateColumns(columns); + sb.append(" WHERE "); + sb.append(DorisToStringVisitor.asString(gen.generateExpression(DorisSchema.DorisDataType.BOOLEAN))); + DorisErrors.addInsertErrors(errors); + return new SQLQueryAdapter(sb.toString(), errors); + } + + @Override + protected void updateValue(DorisColumn column) { + if (Randomly.getBooleanWithSmallProbability()) { + DorisExpression expr = gen.generateExpression(column.getType().getPrimitiveDataType()); + sb.append(DorisToStringVisitor.asString(expr)); + } else { + DorisExpression expr = gen.generateConstant(column.getType().getPrimitiveDataType(), column.isNullable()); + sb.append(DorisToStringVisitor.asString(expr)); + } + + } + +} diff --git a/src/sqlancer/doris/gen/DorisViewGenerator.java b/src/sqlancer/doris/gen/DorisViewGenerator.java new file mode 100644 index 000000000..480607378 --- /dev/null +++ b/src/sqlancer/doris/gen/DorisViewGenerator.java @@ -0,0 +1,38 @@ +package sqlancer.doris.gen; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.doris.DorisErrors; +import sqlancer.doris.DorisProvider.DorisGlobalState; +import sqlancer.doris.visitor.DorisToStringVisitor; + +public final class DorisViewGenerator { + + private DorisViewGenerator() { + } + + public static SQLQueryAdapter getQuery(DorisGlobalState globalState) { + if (globalState.getSchema().getDatabaseTables().size() > globalState.getDbmsSpecificOptions().maxNumTables) { + throw new IgnoreMeException(); + } + int nrColumns = Randomly.smallNumber() + 1; + StringBuilder sb = new StringBuilder("CREATE VIEW "); + sb.append(globalState.getSchema().getFreeViewName()); + sb.append("("); + for (int i = 0; i < nrColumns; i++) { + if (i != 0) { + sb.append(", "); + } + sb.append("c"); + sb.append(i); + } + sb.append(") AS "); + sb.append(DorisToStringVisitor.asString(DorisRandomQuerySynthesizer.generateSelect(globalState, nrColumns))); + ExpectedErrors errors = new ExpectedErrors(); + DorisErrors.addExpressionErrors(errors); + return new SQLQueryAdapter(sb.toString(), errors, true); + } + +} diff --git a/src/sqlancer/doris/oracle/DorisPivotedQuerySynthesisOracle.java b/src/sqlancer/doris/oracle/DorisPivotedQuerySynthesisOracle.java new file mode 100644 index 000000000..a7e8b5b02 --- /dev/null +++ b/src/sqlancer/doris/oracle/DorisPivotedQuerySynthesisOracle.java @@ -0,0 +1,152 @@ +package sqlancer.doris.oracle; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.common.oracle.PivotedQuerySynthesisBase; +import sqlancer.common.query.Query; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.doris.DorisErrors; +import sqlancer.doris.DorisProvider.DorisGlobalState; +import sqlancer.doris.DorisSchema.DorisColumn; +import sqlancer.doris.DorisSchema.DorisDataType; +import sqlancer.doris.DorisSchema.DorisRowValue; +import sqlancer.doris.DorisSchema.DorisTables; +import sqlancer.doris.ast.DorisColumnValue; +import sqlancer.doris.ast.DorisConstant; +import sqlancer.doris.ast.DorisExpression; +import sqlancer.doris.ast.DorisSelect; +import sqlancer.doris.ast.DorisTableReference; +import sqlancer.doris.ast.DorisUnaryPostfixOperation; +import sqlancer.doris.ast.DorisUnaryPrefixOperation; +import sqlancer.doris.gen.DorisNewExpressionGenerator; +import sqlancer.doris.visitor.DorisExpectedValueVisitor; +import sqlancer.doris.visitor.DorisToStringVisitor; + +public class DorisPivotedQuerySynthesisOracle + extends PivotedQuerySynthesisBase { + + private List fetchColumns; + + public DorisPivotedQuerySynthesisOracle(DorisGlobalState globalState) { + super(globalState); + DorisErrors.addExpressionErrors(errors); + DorisErrors.addInsertErrors(errors); + } + + @Override + protected Query getRectifiedQuery() throws Exception { + DorisTables randomTables = globalState.getSchema().getRandomTableNonEmptyAndViewTables(); + List columns = randomTables.getColumns(); + DorisSelect selectStatement = new DorisSelect(); + boolean isDistinct = Randomly.getBoolean(); + selectStatement.setDistinct(isDistinct); + pivotRow = randomTables.getRandomRowValue(globalState.getConnection()); + fetchColumns = columns; + selectStatement.setFetchColumns(fetchColumns.stream() + .map(c -> new DorisColumnValue(getFetchValueAliasedColumn(c), pivotRow.getValues().get(c))) + .collect(Collectors.toList())); + selectStatement.setFromList( + randomTables.getTables().stream().map(t -> new DorisTableReference(t)).collect(Collectors.toList())); + DorisExpression whereClause = generateRectifiedExpression(columns, pivotRow); + selectStatement.setWhereClause(whereClause); + List groupByClause = generateGroupByClause(columns, pivotRow); + selectStatement.setGroupByExpressions(groupByClause); + DorisExpression limitClause = generateLimit(); + selectStatement.setLimitClause(limitClause); + if (limitClause != null) { + DorisExpression offsetClause = generateOffset(); + selectStatement.setOffsetClause(offsetClause); + } + DorisNewExpressionGenerator gen = new DorisNewExpressionGenerator(globalState); + gen.setColumns(columns); + if (!isDistinct) { + List constants = new ArrayList<>(); + constants.add(new DorisConstant.DorisIntConstant( + Randomly.smallNumber() % selectStatement.getFetchColumns().size() + 1)); + selectStatement.setOrderByClauses(constants); + } + return new SQLQueryAdapter(DorisToStringVisitor.asString(selectStatement), errors); + } + + private DorisExpression generateRectifiedExpression(List columns, DorisRowValue pivotRow) { + DorisNewExpressionGenerator gen = new DorisNewExpressionGenerator(globalState).setColumns(columns); + gen.setRowValue(pivotRow); + DorisExpression expr = gen.generateExpressionWithExpectedResult(DorisDataType.BOOLEAN); + DorisExpression result = null; + if (expr.getExpectedValue().isNull()) { + result = new DorisUnaryPostfixOperation(expr, DorisUnaryPostfixOperation.DorisUnaryPostfixOperator.IS_NULL); + } else if (!expr.getExpectedValue().cast(DorisDataType.BOOLEAN).asBoolean()) { + result = new DorisUnaryPrefixOperation(expr, DorisUnaryPrefixOperation.DorisUnaryPrefixOperator.NOT); + } + rectifiedPredicates.add(result); + return result; + } + + @Override + protected Query getContainmentCheckQuery(Query pivotRowQuery) throws Exception { + StringBuilder sb = new StringBuilder(); + sb.append("SELECT * FROM ("); + sb.append(pivotRowQuery.getUnterminatedQueryString()); + sb.append(") as result WHERE "); + int i = 0; + for (DorisColumn c : fetchColumns) { + if (i++ != 0) { + sb.append(" AND "); + } + sb.append("result."); + sb.append(c.getTable().getName()); + sb.append(c.getName()); + if (pivotRow.getValues().get(c).isNull()) { + sb.append(" IS NULL "); + } else { + sb.append(" = "); + sb.append(pivotRow.getValues().get(c).toString()); + } + } + String resultingQueryString = sb.toString(); + return new SQLQueryAdapter(resultingQueryString, errors); + } + + private DorisColumn getFetchValueAliasedColumn(DorisColumn c) { + DorisColumn aliasedColumn = new DorisColumn(c.getName() + " AS " + c.getTable().getName() + c.getName(), + c.getType(), false, false); + aliasedColumn.setTable(c.getTable()); + return aliasedColumn; + } + + @Override + protected String getExpectedValues(DorisExpression expr) { + return DorisExpectedValueVisitor.asExpectedValues(expr); + } + + private List generateGroupByClause(List columns, DorisRowValue rowValue) { + if (Randomly.getBoolean()) { + return columns.stream().map(c -> new DorisColumnValue(c, rowValue.getValues().get(c))) + .collect(Collectors.toList()); + } else { + return Collections.emptyList(); + } + } + + private DorisExpression generateLimit() { + if (Randomly.getBoolean()) { + return DorisConstant.createIntConstant(Integer.MAX_VALUE); + } else { + return null; + } + } + + private DorisExpression generateOffset() { + if (Randomly.getBoolean()) { + return DorisConstant.createIntConstant(0); + } else { + return null; + } + } + +} diff --git a/src/sqlancer/doris/oracle/tlp/DorisQueryPartitioningAggregateTester.java b/src/sqlancer/doris/oracle/tlp/DorisQueryPartitioningAggregateTester.java new file mode 100644 index 000000000..afc1dde43 --- /dev/null +++ b/src/sqlancer/doris/oracle/tlp/DorisQueryPartitioningAggregateTester.java @@ -0,0 +1,202 @@ +package sqlancer.doris.oracle.tlp; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import sqlancer.ComparatorHelper; +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.query.SQLancerResultSet; +import sqlancer.doris.DorisErrors; +import sqlancer.doris.DorisProvider.DorisGlobalState; +import sqlancer.doris.DorisSchema.DorisCompositeDataType; +import sqlancer.doris.DorisSchema.DorisDataType; +import sqlancer.doris.ast.DorisAggregateOperation; +import sqlancer.doris.ast.DorisAggregateOperation.DorisAggregateFunction; +import sqlancer.doris.ast.DorisAlias; +import sqlancer.doris.ast.DorisBinaryArithmeticOperation; +import sqlancer.doris.ast.DorisBinaryOperation; +import sqlancer.doris.ast.DorisCastOperation; +import sqlancer.doris.ast.DorisConstant; +import sqlancer.doris.ast.DorisExpression; +import sqlancer.doris.ast.DorisFunction; +import sqlancer.doris.ast.DorisSelect; +import sqlancer.doris.ast.DorisUnaryPostfixOperation; +import sqlancer.doris.ast.DorisUnaryPostfixOperation.DorisUnaryPostfixOperator; +import sqlancer.doris.ast.DorisUnaryPrefixOperation; +import sqlancer.doris.ast.DorisUnaryPrefixOperation.DorisUnaryPrefixOperator; +import sqlancer.doris.visitor.DorisToStringVisitor; + +public class DorisQueryPartitioningAggregateTester extends DorisQueryPartitioningBase + implements TestOracle { + + private String firstResult; + private String secondResult; + private String originalQuery; + private String metamorphicQuery; + + public DorisQueryPartitioningAggregateTester(DorisGlobalState state) { + super(state); + DorisErrors.addExpressionErrors(errors); + DorisErrors.addInsertErrors(errors); + } + + @Override + public void check() throws SQLException { + super.check(); + DorisAggregateFunction aggregateFunction = Randomly.fromOptions(DorisAggregateFunction.MAX, + DorisAggregateFunction.MIN, DorisAggregateFunction.SUM, DorisAggregateFunction.COUNT, + DorisAggregateFunction.AVG); + DorisFunction aggregate = (DorisAggregateOperation) gen + .generateArgsForAggregate(aggregateFunction); + List fetchColumns = new ArrayList<>(); + fetchColumns.add(aggregate); + while (Randomly.getBooleanWithRatherLowProbability()) { + fetchColumns.add((DorisAggregateOperation) gen.generateAggregate()); + } + select.setFetchColumns(Arrays.asList(aggregate)); + if (Randomly.getBooleanWithRatherLowProbability()) { + List constants = new ArrayList<>(); + constants.add( + new DorisConstant.DorisIntConstant(Randomly.smallNumber() % select.getFetchColumns().size() + 1)); + select.setOrderByClauses(constants); + } + originalQuery = DorisToStringVisitor.asString(select); + firstResult = getAggregateResult(originalQuery); + metamorphicQuery = createMetamorphicUnionQuery(select, aggregate, select.getFromList()); + secondResult = getAggregateResult(metamorphicQuery); + + state.getState().getLocalState().log( + "--" + originalQuery + ";\n--" + metamorphicQuery + "\n-- " + firstResult + "\n-- " + secondResult); + if (firstResult == null && secondResult == null) { + return; + } + if (firstResult == null) { + throw new AssertionError(); + } + firstResult = firstResult.replace("\0", ""); + if (firstResult.contentEquals("0") && secondResult == null) { + return; + } + if (secondResult == null) { + throw new AssertionError(); + } + secondResult = secondResult.replace("\0", ""); + if (!firstResult.contentEquals(secondResult) && !ComparatorHelper.isEqualDouble(firstResult, secondResult)) { + throw new AssertionError(); + } + + } + + private String createMetamorphicUnionQuery(DorisSelect select, DorisFunction aggregate, + List from) { + String metamorphicQuery; + DorisExpression whereClause = gen.generateExpression(DorisDataType.BOOLEAN); + DorisExpression negatedClause = new DorisUnaryPrefixOperation(whereClause, DorisUnaryPrefixOperator.NOT); + DorisExpression notNullClause = new DorisUnaryPostfixOperation(whereClause, DorisUnaryPostfixOperator.IS_NULL); + List mappedAggregate = mapped(aggregate); + DorisSelect leftSelect = getSelect(mappedAggregate, from, whereClause, select.getJoinList()); + DorisSelect middleSelect = getSelect(mappedAggregate, from, negatedClause, select.getJoinList()); + DorisSelect rightSelect = getSelect(mappedAggregate, from, notNullClause, select.getJoinList()); + if (Randomly.getBooleanWithSmallProbability()) { + leftSelect.setGroupByExpressions(groupByExpression); + middleSelect.setGroupByExpressions(groupByExpression); + rightSelect.setGroupByExpressions(groupByExpression); + } + metamorphicQuery = "SELECT " + getOuterAggregateFunction(aggregate) + " FROM ("; + metamorphicQuery += DorisToStringVisitor.asString(leftSelect) + " UNION ALL " + + DorisToStringVisitor.asString(middleSelect) + " UNION ALL " + + DorisToStringVisitor.asString(rightSelect); + metamorphicQuery += ") as asdf"; + return metamorphicQuery; + } + + private String getAggregateResult(String queryString) throws SQLException { + String resultString; + SQLQueryAdapter q = new SQLQueryAdapter(queryString, errors); + try (SQLancerResultSet result = q.executeAndGet(state)) { + if (result == null) { + throw new IgnoreMeException(); + } + if (!result.next()) { + resultString = null; + } else { + resultString = result.getString(1); + } + return resultString; + } catch (SQLException e) { + if (!e.getMessage().contains("Not implemented type")) { + throw new AssertionError(queryString, e); + } else { + throw new IgnoreMeException(); + } + } + } + + private List mapped(DorisFunction aggregate) { + + DorisCastOperation count; + switch (aggregate.getFunc()) { + case COUNT: + case MAX: + case MIN: + case SUM: + return aliasArgs(Arrays.asList(aggregate)); + case AVG: + DorisFunction sum = new DorisFunction<>(aggregate.getArgs(), + DorisAggregateFunction.SUM); + count = new DorisCastOperation(new DorisFunction<>(aggregate.getArgs(), DorisAggregateFunction.COUNT), + new DorisCompositeDataType(DorisDataType.FLOAT, 8)); + return aliasArgs(Arrays.asList(sum, count)); + case STDDEV_POP: + DorisFunction sumSquared = new DorisFunction<>( + Arrays.asList(new DorisBinaryOperation(aggregate.getArgs().get(0), aggregate.getArgs().get(0), + DorisBinaryArithmeticOperation.DorisBinaryArithmeticOperator.MULTIPLICATION)), + DorisAggregateFunction.SUM); + count = new DorisCastOperation(new DorisFunction<>(aggregate.getArgs(), DorisAggregateFunction.COUNT), + new DorisCompositeDataType(DorisDataType.FLOAT, 8)); + DorisFunction avg = new DorisFunction<>(aggregate.getArgs(), + DorisAggregateFunction.AVG); + return aliasArgs(Arrays.asList(sumSquared, count, avg)); + default: + throw new AssertionError(aggregate.getFunc()); + } + } + + private List aliasArgs(List originalAggregateArgs) { + List args = new ArrayList<>(); + int i = 0; + for (DorisExpression expr : originalAggregateArgs) { + args.add(new DorisAlias(expr, "agg" + i++)); + } + return args; + } + + private String getOuterAggregateFunction(DorisFunction aggregate) { + switch (aggregate.getFunc()) { + case STDDEV_POP: + return "sqrt(SUM(agg0)/SUM(agg1)-SUM(agg2)*SUM(agg2))"; + case AVG: + return "SUM(agg0::FLOAT)/SUM(agg1)::FLOAT"; + case COUNT: + return DorisAggregateFunction.SUM.toString() + "(agg0)"; + default: + return aggregate.getFunc().toString() + "(agg0)"; + } + } + + private DorisSelect getSelect(List aggregates, List from, + DorisExpression whereClause, List joinList) { + DorisSelect leftSelect = new DorisSelect(); + leftSelect.setFetchColumns(aggregates); + leftSelect.setFromList(from); + leftSelect.setWhereClause(whereClause); + leftSelect.setJoinList(joinList); + return leftSelect; + } + +} diff --git a/src/sqlancer/doris/oracle/tlp/DorisQueryPartitioningBase.java b/src/sqlancer/doris/oracle/tlp/DorisQueryPartitioningBase.java new file mode 100644 index 000000000..553e7739a --- /dev/null +++ b/src/sqlancer/doris/oracle/tlp/DorisQueryPartitioningBase.java @@ -0,0 +1,82 @@ +package sqlancer.doris.oracle.tlp; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.common.gen.ExpressionGenerator; +import sqlancer.common.oracle.TernaryLogicPartitioningOracleBase; +import sqlancer.common.oracle.TestOracle; +import sqlancer.doris.DorisErrors; +import sqlancer.doris.DorisProvider.DorisGlobalState; +import sqlancer.doris.DorisSchema; +import sqlancer.doris.DorisSchema.DorisColumn; +import sqlancer.doris.DorisSchema.DorisTable; +import sqlancer.doris.DorisSchema.DorisTables; +import sqlancer.doris.ast.DorisColumnReference; +import sqlancer.doris.ast.DorisColumnValue; +import sqlancer.doris.ast.DorisExpression; +import sqlancer.doris.ast.DorisJoin; +import sqlancer.doris.ast.DorisSelect; +import sqlancer.doris.ast.DorisTableReference; +import sqlancer.doris.gen.DorisNewExpressionGenerator; + +public class DorisQueryPartitioningBase extends TernaryLogicPartitioningOracleBase + implements TestOracle { + + DorisSchema s; + DorisTables targetTables; + DorisNewExpressionGenerator gen; + DorisSelect select; + + List groupByExpression; + + public DorisQueryPartitioningBase(DorisGlobalState state) { + super(state); + DorisErrors.addExpressionErrors(errors); + DorisErrors.addInsertErrors(errors); + } + + @Override + public void check() throws SQLException { + s = state.getSchema(); + targetTables = s.getRandomTableNonEmptyTables(); + gen = new DorisNewExpressionGenerator(state).setColumns(targetTables.getColumns()); + List allColumnValues = targetTables.getColumns().stream() + .map(c -> new DorisColumnValue(c, null)).collect(Collectors.toList()); + HashSet columnOfLeafNode = new HashSet<>(); + gen.setColumnOfLeafNode(columnOfLeafNode); + initializeTernaryPredicateVariants(); + select = new DorisSelect(); + columnOfLeafNode.addAll(allColumnValues); + groupByExpression = new ArrayList<>(allColumnValues); + select.setFetchColumns(generateFetchColumns()); + List tables = targetTables.getTables(); + List tableList = tables.stream().map(t -> new DorisTableReference(t)) + .collect(Collectors.toList()); + List joins = DorisJoin.getJoins(tableList, state); + select.setJoinList(joins.stream().collect(Collectors.toList())); + select.setFromList(tableList.stream().collect(Collectors.toList())); + select.setWhereClause(null); + } + + List generateFetchColumns() { + List columns = new ArrayList<>(); + if (Randomly.getBoolean()) { + columns.add(new DorisColumnReference(new DorisColumn("*", null, false, false))); + } else { + columns = Randomly.nonEmptySubset(targetTables.getColumns()).stream().map(c -> new DorisColumnReference(c)) + .collect(Collectors.toList()); + } + return columns; + } + + @Override + protected ExpressionGenerator getGen() { + return gen; + } + +} diff --git a/src/sqlancer/doris/oracle/tlp/DorisQueryPartitioningDistinctTester.java b/src/sqlancer/doris/oracle/tlp/DorisQueryPartitioningDistinctTester.java new file mode 100644 index 000000000..dc350ea6f --- /dev/null +++ b/src/sqlancer/doris/oracle/tlp/DorisQueryPartitioningDistinctTester.java @@ -0,0 +1,44 @@ +package sqlancer.doris.oracle.tlp; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import sqlancer.ComparatorHelper; +import sqlancer.doris.DorisErrors; +import sqlancer.doris.DorisProvider.DorisGlobalState; +import sqlancer.doris.visitor.DorisToStringVisitor; + +public class DorisQueryPartitioningDistinctTester extends DorisQueryPartitioningBase { + + public DorisQueryPartitioningDistinctTester(DorisGlobalState state) { + super(state); + DorisErrors.addExpressionErrors(errors); + DorisErrors.addInsertErrors(errors); + } + + @Override + public void check() throws SQLException { + super.check(); + select.setDistinct(true); + select.setWhereClause(null); + String originalQueryString = DorisToStringVisitor.asString(select); + List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); + + select.setWhereClause(predicate); + String firstQueryString = DorisToStringVisitor.asString(select); + select.setWhereClause(negatedPredicate); + String secondQueryString = DorisToStringVisitor.asString(select); + select.setWhereClause(isNullPredicate); + String thirdQueryString = DorisToStringVisitor.asString(select); + List combinedString = new ArrayList<>(); + + String unionString = "SELECT DISTINCT * FROM (" + firstQueryString + " UNION ALL " + secondQueryString + + " UNION ALL " + thirdQueryString + ") tmpTable"; + combinedString.add(unionString); + List secondResultSet = ComparatorHelper.getResultSetFirstColumnAsString(unionString, errors, state); + ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, + state, ComparatorHelper::canonicalizeResultValue); + } + +} diff --git a/src/sqlancer/doris/oracle/tlp/DorisQueryPartitioningGroupByTester.java b/src/sqlancer/doris/oracle/tlp/DorisQueryPartitioningGroupByTester.java new file mode 100644 index 000000000..97ade18e3 --- /dev/null +++ b/src/sqlancer/doris/oracle/tlp/DorisQueryPartitioningGroupByTester.java @@ -0,0 +1,52 @@ +package sqlancer.doris.oracle.tlp; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.ComparatorHelper; +import sqlancer.Randomly; +import sqlancer.doris.DorisErrors; +import sqlancer.doris.DorisProvider.DorisGlobalState; +import sqlancer.doris.ast.DorisColumnReference; +import sqlancer.doris.ast.DorisExpression; +import sqlancer.doris.visitor.DorisToStringVisitor; + +public class DorisQueryPartitioningGroupByTester extends DorisQueryPartitioningBase { + + public DorisQueryPartitioningGroupByTester(DorisGlobalState state) { + super(state); + DorisErrors.addExpressionErrors(errors); + DorisErrors.addInsertErrors(errors); + } + + @Override + public void check() throws SQLException { + super.check(); + select.setGroupByExpressions(select.getFetchColumns()); + select.setWhereClause(null); + String originalQueryString = DorisToStringVisitor.asString(select); + + List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); + + select.setWhereClause(predicate); + String firstQueryString = DorisToStringVisitor.asString(select); + select.setWhereClause(negatedPredicate); + String secondQueryString = DorisToStringVisitor.asString(select); + select.setWhereClause(isNullPredicate); + String thirdQueryString = DorisToStringVisitor.asString(select); + List combinedString = new ArrayList<>(); + List secondResultSet = ComparatorHelper.getCombinedResultSetNoDuplicates(firstQueryString, + secondQueryString, thirdQueryString, combinedString, true, state, errors); + ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, + state, ComparatorHelper::canonicalizeResultValue); + } + + @Override + List generateFetchColumns() { + return Randomly.nonEmptySubset(targetTables.getColumns()).stream().map(c -> new DorisColumnReference(c)) + .collect(Collectors.toList()); + } + +} diff --git a/src/sqlancer/doris/oracle/tlp/DorisQueryPartitioningHavingTester.java b/src/sqlancer/doris/oracle/tlp/DorisQueryPartitioningHavingTester.java new file mode 100644 index 000000000..434ee01d5 --- /dev/null +++ b/src/sqlancer/doris/oracle/tlp/DorisQueryPartitioningHavingTester.java @@ -0,0 +1,71 @@ +package sqlancer.doris.oracle.tlp; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import sqlancer.ComparatorHelper; +import sqlancer.Randomly; +import sqlancer.common.oracle.TestOracle; +import sqlancer.doris.DorisErrors; +import sqlancer.doris.DorisProvider.DorisGlobalState; +import sqlancer.doris.DorisSchema; +import sqlancer.doris.ast.DorisConstant; +import sqlancer.doris.ast.DorisExpression; +import sqlancer.doris.visitor.DorisToStringVisitor; + +public class DorisQueryPartitioningHavingTester extends DorisQueryPartitioningBase + implements TestOracle { + + public DorisQueryPartitioningHavingTester(DorisGlobalState state) { + super(state); + DorisErrors.addExpressionErrors(errors); + DorisErrors.addInsertErrors(errors); + } + + @Override + public void check() throws SQLException { + super.check(); + if (Randomly.getBoolean()) { + select.setWhereClause(gen.generateExpression(DorisSchema.DorisDataType.BOOLEAN)); + } + select.setFetchColumns(groupByExpression); + boolean orderBy = Randomly.getBoolean(); + if (orderBy) { + List constants = new ArrayList<>(); + constants.add( + new DorisConstant.DorisIntConstant(Randomly.smallNumber() % select.getFetchColumns().size() + 1)); + select.setOrderByClauses(constants); + } + select.setGroupByExpressions(groupByExpression); + select.setHavingClause(null); + String originalQueryString = DorisToStringVisitor.asString(select); + List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); + + select.setHavingClause(predicate); + String firstQueryString = DorisToStringVisitor.asString(select); + select.setHavingClause(negatedPredicate); + String secondQueryString = DorisToStringVisitor.asString(select); + select.setHavingClause(isNullPredicate); + String thirdQueryString = DorisToStringVisitor.asString(select); + List combinedString = new ArrayList<>(); + List secondResultSet = ComparatorHelper.getCombinedResultSet(firstQueryString, secondQueryString, + thirdQueryString, combinedString, !orderBy, state, errors); + ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, + state, ComparatorHelper::canonicalizeResultValue); + } + + @Override + protected DorisExpression generatePredicate() { + return gen.generateHavingClause(); + } + + @Override + List generateFetchColumns() { + gen.setAllowAggregateFunctions(true); + List expressions = gen.generateExpressions(Randomly.smallNumber() + 1); + gen.setAllowAggregateFunctions(false); + return expressions; + } + +} diff --git a/src/sqlancer/doris/utils/DorisNumberUtils.java b/src/sqlancer/doris/utils/DorisNumberUtils.java new file mode 100644 index 000000000..01dd7238d --- /dev/null +++ b/src/sqlancer/doris/utils/DorisNumberUtils.java @@ -0,0 +1,111 @@ +package sqlancer.doris.utils; + +import java.text.SimpleDateFormat; +import java.util.Date; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public final class DorisNumberUtils { + private static Pattern numberPattern = Pattern.compile("-?[0-9]+(\\\\.[0-9]+)?"); + private static Pattern integerPattern = Pattern.compile("^[-\\+]?[\\d]*$"); + private static Pattern datePattern = Pattern + .compile("^([1-9]\\d{3}-)(([0]{0,1}[1-9]-)|([1][0-2]-))(([0-3]{0,1}[0-9]))$"); + private static Pattern datetimePattern = Pattern.compile( + "((([0-9]{3}[1-9]|[0-9]{2}[1-9][0-9]{1}|[0-9]{1}[1-9][0-9]{2}|[1-9][0-9]{3})-(((0[13578]|1[02])-(0[1-9]|[12][0-9]|3[01]))|((0[469]|11)-(0[1-9]|[12][0-9]|30))|(02-(0[1-9]|[1][0-9]|2[0-8]))))|((([0-9]{2})(0[48]|[2468][048]|[13579][26])|((0[48]|[2468][048]|[3579][26])00))-02-29))\\\\s+([0-1]?[0-9]|2[0-3]):([0-5][0-9]):([0-5][0-9])\n"); + + private DorisNumberUtils() { + } + + public static boolean isNumber(String str) { + Matcher m = numberPattern.matcher(str); + return m.matches(); + } + + public static boolean isInteger(String str) { + Matcher m = integerPattern.matcher(str); + return m.matches(); + } + + public static boolean isDate(String str) { + Matcher m = datePattern.matcher(str); + return m.matches(); + } + + public static boolean isDatetime(String str) { + Matcher m = datetimePattern.matcher(str); + return m.matches(); + } + + public static String timestampToDateText(long ts) { + SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd"); + return dateFormat.format(ts); + } + + public static String timestampToDatetimeText(long ts) { + SimpleDateFormat datetimeFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); + return datetimeFormat.format(ts); + } + + public static String dateTextToDatetimeText(String date) { + // '2021-03-12' -> '2021-03-12 00:00:00' + return date + " 00:00:00"; + } + + public static String datetimeTextToDateText(String datetime) { + // '2021-03-12 00:00:00' -> '2021-03-12' + return datetime.substring(0, 10); + } + + public static boolean datetimeEqual(String dt1, String dt2) { + String datetime1 = dt1; + String datetime2 = dt2; + if (isDate(dt1)) { + datetime1 = dateTextToDatetimeText(dt1); + } + if (isDate(dt2)) { + datetime2 = dateTextToDatetimeText(dt2); + } + return datetime1.contentEquals(datetime2); + } + + public static boolean dateEqual(String d1, String d2) { + String date1 = d1; + String date2 = d2; + if (isDatetime(d1)) { + date1 = datetimeTextToDateText(d1); + } + if (isDatetime(d2)) { + date2 = datetimeTextToDateText(d2); + } + return date1.contentEquals(date2); + } + + public static boolean dateLessThan(String d1, String d2) { + String date1 = d1; + String date2 = d2; + if (isDatetime(d1)) { + date1 = datetimeTextToDateText(d1); + } + if (isDatetime(d2)) { + date2 = datetimeTextToDateText(d2); + } + return date1.compareTo(date2) < 0; + } + + public static boolean datetimeLessThan(String dt1, String dt2) { + String datetime1 = dt1; + String datetime2 = dt2; + if (isDate(dt1)) { + datetime1 = dateTextToDatetimeText(dt1); + } + if (isDate(dt2)) { + datetime2 = dateTextToDatetimeText(dt2); + } + return datetime1.compareTo(datetime2) < 0; + } + + public static String getCurrentTimeText() { + SimpleDateFormat datetimeFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); + return datetimeFormat.format(new Date()); + } +} diff --git a/src/sqlancer/doris/visitor/DorisExpectedValueVisitor.java b/src/sqlancer/doris/visitor/DorisExpectedValueVisitor.java new file mode 100644 index 000000000..10c90e47e --- /dev/null +++ b/src/sqlancer/doris/visitor/DorisExpectedValueVisitor.java @@ -0,0 +1,151 @@ +package sqlancer.doris.visitor; + +import java.util.List; + +import sqlancer.doris.ast.DorisAlias; +import sqlancer.doris.ast.DorisBetweenOperation; +import sqlancer.doris.ast.DorisBinaryOperation; +import sqlancer.doris.ast.DorisColumnReference; +import sqlancer.doris.ast.DorisConstant; +import sqlancer.doris.ast.DorisExpression; +import sqlancer.doris.ast.DorisFunction; +import sqlancer.doris.ast.DorisInOperation; +import sqlancer.doris.ast.DorisJoin; +import sqlancer.doris.ast.DorisOrderByTerm; +import sqlancer.doris.ast.DorisPostfixText; +import sqlancer.doris.ast.DorisSelect; +import sqlancer.doris.ast.DorisTableReference; +import sqlancer.doris.ast.DorisUnaryPostfixOperation; +import sqlancer.doris.ast.DorisUnaryPrefixOperation; + +public class DorisExpectedValueVisitor { + + protected final StringBuilder sb = new StringBuilder(); + + private void print(DorisExpression expr) { + sb.append(DorisToStringVisitor.asString(expr)); + sb.append(" -- "); + sb.append(((DorisExpression) expr).getExpectedValue()); + sb.append("\n"); + } + + public void visit(DorisExpression expr) { + assert expr != null; + if (expr instanceof DorisColumnReference) { + visit((DorisColumnReference) expr); + } else if (expr instanceof DorisUnaryPostfixOperation) { + visit((DorisUnaryPostfixOperation) expr); + } else if (expr instanceof DorisUnaryPrefixOperation) { + visit((DorisUnaryPrefixOperation) expr); + } else if (expr instanceof DorisBinaryOperation) { + visit((DorisBinaryOperation) expr); + } else if (expr instanceof DorisTableReference) { + visit((DorisTableReference) expr); + } else if (expr instanceof DorisFunction) { + visit((DorisFunction) expr); + } else if (expr instanceof DorisBetweenOperation) { + visit((DorisBetweenOperation) expr); + } else if (expr instanceof DorisInOperation) { + visit((DorisInOperation) expr); + } else if (expr instanceof DorisOrderByTerm) { + visit((DorisOrderByTerm) expr); + } else if (expr instanceof DorisAlias) { + visit((DorisAlias) expr); + } else if (expr instanceof DorisPostfixText) { + visit((DorisPostfixText) expr); + } else if (expr instanceof DorisConstant) { + visit((DorisConstant) expr); + } else if (expr instanceof DorisSelect) { + visit((DorisSelect) expr); + } else if (expr instanceof DorisJoin) { + visit((DorisJoin) expr); + } else { + throw new AssertionError(expr); + } + } + + public void visit(DorisColumnReference c) { + print(c); + } + + public void visit(DorisUnaryPostfixOperation op) { + print(op); + visit(op.getExpr()); + } + + public void visit(DorisUnaryPrefixOperation op) { + print(op); + visit(op.getExpr()); + } + + public void visit(DorisBinaryOperation op) { + visit(op.getLeft()); + visit(op.getRight()); + } + + public void visit(DorisTableReference t) { + print(t); + } + + public void visit(DorisFunction fun) { + print(fun); + visit(fun.getArgs()); + } + + public void visit(List expressions) { + for (DorisExpression expression : expressions) { + visit(expression); + } + } + + public void visit(DorisBetweenOperation op) { + print(op); + visit(op.getLeft()); + visit(op.getMiddle()); + visit(op.getRight()); + } + + public void visit(DorisInOperation op) { + print(op); + visit(op.getLeft()); + visit(op.getRight()); + } + + public void visit(DorisOrderByTerm op) { + print(op); + visit(op.getExpr()); + } + + public void visit(DorisAlias op) { + print(op); + visit(op.getExpr()); + } + + public void visit(DorisPostfixText postFixText) { + print(postFixText); + visit(postFixText.getExpr()); + } + + public void visit(DorisConstant constant) { + print(constant); + } + + public void visit(DorisSelect select) { + print(select.getWhereClause()); + } + + public void visit(DorisJoin join) { + print(join.getOnCondition()); + } + + public String get() { + return sb.toString(); + } + + public static String asExpectedValues(DorisExpression expr) { + DorisExpectedValueVisitor v = new DorisExpectedValueVisitor(); + v.visit(expr); + return v.get(); + } + +} diff --git a/src/sqlancer/doris/visitor/DorisToStringVisitor.java b/src/sqlancer/doris/visitor/DorisToStringVisitor.java new file mode 100644 index 000000000..fcea9ab69 --- /dev/null +++ b/src/sqlancer/doris/visitor/DorisToStringVisitor.java @@ -0,0 +1,168 @@ +package sqlancer.doris.visitor; + +import sqlancer.Randomly; +import sqlancer.common.ast.newast.NewToStringVisitor; +import sqlancer.doris.ast.DorisCaseOperation; +import sqlancer.doris.ast.DorisCastOperation; +import sqlancer.doris.ast.DorisConstant; +import sqlancer.doris.ast.DorisExpression; +import sqlancer.doris.ast.DorisFunctionOperation; +import sqlancer.doris.ast.DorisJoin; +import sqlancer.doris.ast.DorisSelect; + +public class DorisToStringVisitor extends NewToStringVisitor { + + @Override + public void visitSpecific(DorisExpression expr) { + if (expr instanceof DorisConstant) { + visit((DorisConstant) expr); + } else if (expr instanceof DorisSelect) { + visit((DorisSelect) expr); + } else if (expr instanceof DorisJoin) { + visit((DorisJoin) expr); + } else if (expr instanceof DorisCastOperation) { + visit((DorisCastOperation) expr); + } else if (expr instanceof DorisCaseOperation) { + visit((DorisCaseOperation) expr); + } else if (expr instanceof DorisFunctionOperation) { + visit((DorisFunctionOperation) expr); + } else { + throw new AssertionError(expr.getClass()); + } + } + + private void visit(DorisJoin join) { + sb.append(" "); + visit((DorisExpression) join.getLeftTable()); + sb.append(" "); + switch (join.getJoinType()) { + case INNER: + if (Randomly.getBoolean()) { + sb.append("INNER "); + } else { + sb.append("CROSS "); + } + sb.append("JOIN "); + break; + case LEFT: + sb.append("LEFT "); + if (Randomly.getBoolean()) { + sb.append(" OUTER "); + } + sb.append("JOIN "); + break; + case RIGHT: + sb.append("RIGHT "); + if (Randomly.getBoolean()) { + sb.append(" OUTER "); + } + sb.append("JOIN "); + break; + case STRAIGHT: + sb.append("STRAIGHT_JOIN "); + break; + default: + throw new AssertionError(); + } + visit((DorisExpression) join.getRightTable()); + sb.append(" "); + if (join.getOnCondition() != null) { + sb.append("ON "); + visit(join.getOnCondition()); + } + } + + private void visit(DorisConstant constant) { + sb.append(constant.toString()); + } + + private void visit(DorisCastOperation castExpr) { + sb.append("CAST("); + visit(castExpr.getExpr()); + sb.append(" AS "); + sb.append(castExpr.getType().toString()); + sb.append(") "); + } + + private void visit(DorisFunctionOperation func) { + sb.append(func.getFunction().getFunctionName()); + sb.append("("); + + if (func.getArgs() != null) { + for (int i = 0; i < func.getArgs().size(); i++) { + visit(func.getArgs().get(i)); + if (i != func.getArgs().size() - 1) { + sb.append(","); + } + } + } + sb.append(") "); + } + + private void visit(DorisCaseOperation cases) { + sb.append("CASE "); + visit(cases.getExpr()); + sb.append(" "); + for (int i = 0; i < cases.getConditions().size(); i++) { + DorisExpression predicate = cases.getConditions().get(i); + DorisExpression then = cases.getThenClauses().get(i); + sb.append(" WHEN "); + visit(predicate); + sb.append(" THEN "); + visit(then); + sb.append(" "); + } + if (cases.getElseClause() != null) { + sb.append("ELSE "); + visit(cases.getElseClause()); + sb.append(" "); + } + sb.append("END "); + } + + private void visit(DorisSelect select) { + sb.append("SELECT "); + if (select.isDistinct()) { + sb.append("DISTINCT "); + } + visit(select.getFetchColumns()); + sb.append(" FROM "); + visit(select.getFromList()); + if (!select.getFromList().isEmpty() && !select.getJoinList().isEmpty()) { + sb.append(", "); + } + if (!select.getJoinList().isEmpty()) { + visit(select.getJoinList()); + } + if (select.getWhereClause() != null) { + sb.append(" WHERE "); + visit(select.getWhereClause()); + } + if (!select.getGroupByExpressions().isEmpty()) { + sb.append(" GROUP BY "); + visit(select.getGroupByExpressions()); + } + if (select.getHavingClause() != null) { + sb.append(" HAVING "); + visit(select.getHavingClause()); + } + if (!select.getOrderByClauses().isEmpty()) { + sb.append(" ORDER BY "); + visit(select.getOrderByClauses()); + } + if (select.getLimitClause() != null) { + sb.append(" LIMIT "); + visit(select.getLimitClause()); + } + if (select.getOffsetClause() != null) { + sb.append(" OFFSET "); + visit(select.getOffsetClause()); + } + } + + public static String asString(DorisExpression expr) { + DorisToStringVisitor visitor = new DorisToStringVisitor(); + visitor.visit(expr); + return visitor.get(); + } +} diff --git a/src/sqlancer/duckdb/DuckDBBugs.java b/src/sqlancer/duckdb/DuckDBBugs.java new file mode 100644 index 000000000..c91661cb5 --- /dev/null +++ b/src/sqlancer/duckdb/DuckDBBugs.java @@ -0,0 +1,8 @@ +package sqlancer.duckdb; + +public final class DuckDBBugs { + + private DuckDBBugs() { + } + +} diff --git a/src/sqlancer/duckdb/DuckDBErrors.java b/src/sqlancer/duckdb/DuckDBErrors.java index cb45bb96c..394de69d4 100644 --- a/src/sqlancer/duckdb/DuckDBErrors.java +++ b/src/sqlancer/duckdb/DuckDBErrors.java @@ -1,5 +1,9 @@ package sqlancer.duckdb; +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Pattern; + import sqlancer.common.query.ExpectedErrors; public final class DuckDBErrors { @@ -7,7 +11,9 @@ public final class DuckDBErrors { private DuckDBErrors() { } - public static void addExpressionErrors(ExpectedErrors errors) { + public static List getExpressionErrors() { + ArrayList errors = new ArrayList<>(); + errors.add("with non-constant precision is not supported"); errors.add("Like pattern must not end with escape character"); errors.add("Could not convert string"); @@ -27,9 +33,9 @@ public static void addExpressionErrors(ExpectedErrors errors) { errors.add("GROUP BY clause cannot contain aggregates!"); // investigate - addRegexErrors(errors); + errors.addAll(getRegexErrors()); - addFunctionErrors(errors); + errors.addAll(getFunctionErrors()); errors.add("Overflow in multiplication"); errors.add("Out of Range"); @@ -38,34 +44,49 @@ public static void addExpressionErrors(ExpectedErrors errors) { // collate errors.add("Cannot combine types with different collation!"); errors.add("collations are only supported for type varchar"); + errors.add("COLLATE can only be applied to varchar columns"); - // // https://github.com/cwida/duckdb/issues/532 - errors.add("Not implemented type: DATE"); - errors.add("Not implemented type: TIMESTAMP"); errors.add("Like pattern must not end with escape character!"); // LIKE - errors.add("does not have a column named \"rowid\""); // TODO: this can be removed if we can query whether a - // table supports rowids - errors.add("does not have a column named"); // TODO: this only happens for views whose underlying table has a // removed column errors.add("Contents of view were altered: types don't match!"); errors.add("Not implemented: ROUND(DECIMAL, INTEGER) with non-constant precision is not supported"); + errors.add("ORDER BY non-integer literal has no effect"); + + // timestamp + errors.add("Cannot subtract infinite timestamps"); + errors.add("Timestamp difference is out of bounds"); + + return errors; + } + + public static List getExpressionErrorsRegex() { + ArrayList errors = new ArrayList<>(); + + errors.add(Pattern.compile("Binder Error: Cannot mix values of type .* and .* in BETWEEN clause")); + errors.add(Pattern.compile("Binder Error: Cannot mix values of type .* and .* in CASE expression")); + errors.add(Pattern.compile("Cannot mix values of type .* and .* in COALESCE operator")); + errors.add(Pattern.compile("Cannot compare values of type .* and type .*")); + + return errors; } - private static void addRegexErrors(ExpectedErrors errors) { - errors.add("missing ]"); - errors.add("missing )"); - errors.add("invalid escape sequence"); - errors.add("no argument for repetition operator: "); - errors.add("bad repetition operator"); - errors.add("trailing \\"); - errors.add("invalid perl operator"); - errors.add("invalid character class range"); - errors.add("width is not integer"); + public static void addExpressionErrors(ExpectedErrors errors) { + errors.addAll(getExpressionErrors()); + errors.addAllRegexes(getExpressionErrorsRegex()); + } + + private static List getRegexErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("Invalid Input Error:"); + return errors; } - private static void addFunctionErrors(ExpectedErrors errors) { + private static List getFunctionErrors() { + ArrayList errors = new ArrayList<>(); + errors.add("SUBSTRING cannot handle negative lengths"); errors.add("is undefined outside [-1,1]"); // ACOS etc errors.add("invalid type specifier"); // PRINTF @@ -84,15 +105,19 @@ private static void addFunctionErrors(ExpectedErrors errors) { errors.add("Could not choose a best candidate function for the function call"); // monthname errors.add("expected a numeric precision field"); // ROUND errors.add("with non-constant precision is not supported"); // ROUND + + return errors; } - public static void addInsertErrors(ExpectedErrors errors) { - addRegexErrors(errors); - addFunctionErrors(errors); + public static List getInsertErrors() { + ArrayList errors = new ArrayList<>(); + + errors.addAll(getRegexErrors()); + errors.addAll(getFunctionErrors()); errors.add("NOT NULL constraint failed"); - errors.add("PRIMARY KEY or UNIQUE constraint violated"); - errors.add("duplicate key"); + errors.add("PRIMARY KEY or UNIQUE constraint violation"); + errors.add("Duplicate key"); errors.add("can't be cast because the value is out of range for the destination type"); errors.add("Could not convert string"); errors.add("Unimplemented type for cast"); @@ -104,11 +129,26 @@ public static void addInsertErrors(ExpectedErrors errors) { errors.add("Could not cast value"); errors.add("create unique index, table contains duplicate data"); errors.add("Failed to cast"); + + return errors; } - public static void addGroupByErrors(ExpectedErrors errors) { + public static void addInsertErrors(ExpectedErrors errors) { + errors.addAll(getInsertErrors()); + } + + public static List getGroupByErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("must appear in the GROUP BY clause or must be part of an aggregate function"); errors.add("must appear in the GROUP BY clause or be used in an aggregate function"); errors.add("GROUP BY term out of range"); + + return errors; + } + + public static void addGroupByErrors(ExpectedErrors errors) { + errors.addAll(getGroupByErrors()); } } diff --git a/src/sqlancer/duckdb/DuckDBOptions.java b/src/sqlancer/duckdb/DuckDBOptions.java index 4e1c49cd1..00f85eece 100644 --- a/src/sqlancer/duckdb/DuckDBOptions.java +++ b/src/sqlancer/duckdb/DuckDBOptions.java @@ -1,7 +1,5 @@ package sqlancer.duckdb; -import java.sql.SQLException; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -9,17 +7,6 @@ import com.beust.jcommander.Parameters; import sqlancer.DBMSSpecificOptions; -import sqlancer.OracleFactory; -import sqlancer.common.oracle.CompositeTestOracle; -import sqlancer.common.oracle.TestOracle; -import sqlancer.duckdb.DuckDBOptions.DuckDBOracleFactory; -import sqlancer.duckdb.DuckDBProvider.DuckDBGlobalState; -import sqlancer.duckdb.test.DuckDBNoRECOracle; -import sqlancer.duckdb.test.DuckDBQueryPartitioningAggregateTester; -import sqlancer.duckdb.test.DuckDBQueryPartitioningDistinctTester; -import sqlancer.duckdb.test.DuckDBQueryPartitioningGroupByTester; -import sqlancer.duckdb.test.DuckDBQueryPartitioningHavingTester; -import sqlancer.duckdb.test.DuckDBQueryPartitioningWhereTester; @Parameters(commandDescription = "DuckDB") public class DuckDBOptions implements DBMSSpecificOptions { @@ -93,62 +80,6 @@ public class DuckDBOptions implements DBMSSpecificOptions { @Parameter(names = "--oracle") public List oracles = Arrays.asList(DuckDBOracleFactory.QUERY_PARTITIONING); - public enum DuckDBOracleFactory implements OracleFactory { - NOREC { - - @Override - public TestOracle create(DuckDBGlobalState globalState) throws SQLException { - return new DuckDBNoRECOracle(globalState); - } - - }, - HAVING { - @Override - public TestOracle create(DuckDBGlobalState globalState) throws SQLException { - return new DuckDBQueryPartitioningHavingTester(globalState); - } - }, - WHERE { - @Override - public TestOracle create(DuckDBGlobalState globalState) throws SQLException { - return new DuckDBQueryPartitioningWhereTester(globalState); - } - }, - GROUP_BY { - @Override - public TestOracle create(DuckDBGlobalState globalState) throws SQLException { - return new DuckDBQueryPartitioningGroupByTester(globalState); - } - }, - AGGREGATE { - - @Override - public TestOracle create(DuckDBGlobalState globalState) throws SQLException { - return new DuckDBQueryPartitioningAggregateTester(globalState); - } - - }, - DISTINCT { - @Override - public TestOracle create(DuckDBGlobalState globalState) throws SQLException { - return new DuckDBQueryPartitioningDistinctTester(globalState); - } - }, - QUERY_PARTITIONING { - @Override - public TestOracle create(DuckDBGlobalState globalState) throws SQLException { - List oracles = new ArrayList<>(); - oracles.add(new DuckDBQueryPartitioningWhereTester(globalState)); - oracles.add(new DuckDBQueryPartitioningHavingTester(globalState)); - oracles.add(new DuckDBQueryPartitioningAggregateTester(globalState)); - oracles.add(new DuckDBQueryPartitioningDistinctTester(globalState)); - oracles.add(new DuckDBQueryPartitioningGroupByTester(globalState)); - return new CompositeTestOracle(oracles, globalState); - } - }; - - } - @Override public List getTestOracleFactory() { return oracles; diff --git a/src/sqlancer/duckdb/DuckDBOracleFactory.java b/src/sqlancer/duckdb/DuckDBOracleFactory.java new file mode 100644 index 000000000..8fd0f96af --- /dev/null +++ b/src/sqlancer/duckdb/DuckDBOracleFactory.java @@ -0,0 +1,86 @@ +package sqlancer.duckdb; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import sqlancer.OracleFactory; +import sqlancer.common.oracle.CompositeTestOracle; +import sqlancer.common.oracle.NoRECOracle; +import sqlancer.common.oracle.TLPWhereOracle; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.duckdb.gen.DuckDBExpressionGenerator; +import sqlancer.duckdb.test.DuckDBQueryPartitioningAggregateTester; +import sqlancer.duckdb.test.DuckDBQueryPartitioningDistinctTester; +import sqlancer.duckdb.test.DuckDBQueryPartitioningGroupByTester; +import sqlancer.duckdb.test.DuckDBQueryPartitioningHavingTester; + +public enum DuckDBOracleFactory implements OracleFactory { + NOREC { + @Override + public TestOracle create(DuckDBProvider.DuckDBGlobalState globalState) + throws SQLException { + DuckDBExpressionGenerator gen = new DuckDBExpressionGenerator(globalState); + ExpectedErrors errors = ExpectedErrors.newErrors().with(DuckDBErrors.getExpressionErrors()) + .withRegex(DuckDBErrors.getExpressionErrorsRegex()) + .with("canceling statement due to statement timeout").build(); + return new NoRECOracle<>(globalState, gen, errors); + } + + }, + HAVING { + @Override + public TestOracle create(DuckDBProvider.DuckDBGlobalState globalState) + throws SQLException { + return new DuckDBQueryPartitioningHavingTester(globalState); + } + }, + WHERE { + @Override + public TestOracle create(DuckDBProvider.DuckDBGlobalState globalState) + throws SQLException { + DuckDBExpressionGenerator gen = new DuckDBExpressionGenerator(globalState); + ExpectedErrors expectedErrors = ExpectedErrors.newErrors().with(DuckDBErrors.getExpressionErrors()) + .with(DuckDBErrors.getGroupByErrors()).withRegex(DuckDBErrors.getExpressionErrorsRegex()).build(); + + return new TLPWhereOracle<>(globalState, gen, expectedErrors); + } + }, + GROUP_BY { + @Override + public TestOracle create(DuckDBProvider.DuckDBGlobalState globalState) + throws SQLException { + return new DuckDBQueryPartitioningGroupByTester(globalState); + } + }, + AGGREGATE { + @Override + public TestOracle create(DuckDBProvider.DuckDBGlobalState globalState) + throws SQLException { + return new DuckDBQueryPartitioningAggregateTester(globalState); + } + + }, + DISTINCT { + @Override + public TestOracle create(DuckDBProvider.DuckDBGlobalState globalState) + throws SQLException { + return new DuckDBQueryPartitioningDistinctTester(globalState); + } + }, + QUERY_PARTITIONING { + @Override + public TestOracle create(DuckDBProvider.DuckDBGlobalState globalState) + throws Exception { + List> oracles = new ArrayList<>(); + oracles.add(WHERE.create(globalState)); + oracles.add(HAVING.create(globalState)); + oracles.add(AGGREGATE.create(globalState)); + oracles.add(DISTINCT.create(globalState)); + oracles.add(GROUP_BY.create(globalState)); + return new CompositeTestOracle(oracles, globalState); + } + }; + +} diff --git a/src/sqlancer/duckdb/DuckDBProvider.java b/src/sqlancer/duckdb/DuckDBProvider.java index 3164b7d6a..59b26eac7 100644 --- a/src/sqlancer/duckdb/DuckDBProvider.java +++ b/src/sqlancer/duckdb/DuckDBProvider.java @@ -11,6 +11,7 @@ import sqlancer.AbstractAction; import sqlancer.DatabaseProvider; import sqlancer.IgnoreMeException; +import sqlancer.MainOptions; import sqlancer.Randomly; import sqlancer.SQLConnection; import sqlancer.SQLGlobalState; @@ -143,8 +144,12 @@ public SQLConnection createDatabase(DuckDBGlobalState globalState) throws SQLExc String url = "jdbc:duckdb:" + databaseFile; tryDeleteDatabase(databaseFile); - Connection conn = DriverManager.getConnection(url, globalState.getOptions().getUserName(), - globalState.getOptions().getPassword()); + MainOptions options = globalState.getOptions(); + if (!(options.isDefaultUsername() && options.isDefaultPassword())) { + throw new AssertionError("DuckDB doesn't support credentials (username/password)"); + } + + Connection conn = DriverManager.getConnection(url); Statement stmt = conn.createStatement(); stmt.execute("PRAGMA checkpoint_threshold='1 byte';"); stmt.close(); diff --git a/src/sqlancer/duckdb/DuckDBSchema.java b/src/sqlancer/duckdb/DuckDBSchema.java index 6b94ce137..857b1e008 100644 --- a/src/sqlancer/duckdb/DuckDBSchema.java +++ b/src/sqlancer/duckdb/DuckDBSchema.java @@ -4,7 +4,6 @@ import java.sql.SQLException; import java.sql.Statement; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import sqlancer.IgnoreMeException; @@ -133,6 +132,7 @@ public DuckDBColumn(String name, DuckDBCompositeDataType columnType, boolean isP this.isNullable = isNullable; } + @Override public boolean isPrimaryKey() { return isPrimaryKey; } @@ -203,7 +203,7 @@ private static DuckDBCompositeDataType getColumnType(String typeString) { case "TIMESTAMP": primitiveType = DuckDBDataType.TIMESTAMP; break; - case "NULL": + case "\"NULL\"": primitiveType = DuckDBDataType.NULL; break; case "INTERVAL": @@ -218,8 +218,8 @@ private static DuckDBCompositeDataType getColumnType(String typeString) { public static class DuckDBTable extends AbstractRelationalTable { - public DuckDBTable(String tableName, List columns, boolean isView) { - super(tableName, columns, Collections.emptyList(), isView); + public DuckDBTable(String tableName, List columns, List indexes, boolean isView) { + super(tableName, columns, indexes, isView); } } @@ -232,8 +232,9 @@ public static DuckDBSchema fromConnection(SQLConnection con, String databaseName continue; // TODO: unexpected? } List databaseColumns = getTableColumns(con, tableName); - boolean isView = tableName.startsWith("v"); - DuckDBTable t = new DuckDBTable(tableName, databaseColumns, isView); + boolean isView = matchesViewName(tableName); + List indexes = getIndexes(con, tableName); + DuckDBTable t = new DuckDBTable(tableName, databaseColumns, indexes, isView); for (DuckDBColumn c : databaseColumns) { c.setTable(t); } @@ -243,6 +244,21 @@ public static DuckDBSchema fromConnection(SQLConnection con, String databaseName return new DuckDBSchema(databaseTables); } + private static List getIndexes(SQLConnection con, String tableName) throws SQLException { + List indexes = new ArrayList<>(); + try (Statement s = con.createStatement()) { + try (ResultSet rs = s.executeQuery(String.format( + "SELECT index_name FROM duckdb_indexes() WHERE database_name = current_database() AND table_name = '%s';", + tableName))) { + while (rs.next()) { + String indexName = rs.getString("INDEX_NAME"); + indexes.add(TableIndex.create(indexName)); + } + } + } + return indexes; + } + private static List getTableNames(SQLConnection con) throws SQLException { List tableNames = new ArrayList<>(); try (Statement s = con.createStatement()) { @@ -269,7 +285,7 @@ private static List getTableColumns(SQLConnection con, String tabl } } } - if (columns.stream().noneMatch(c -> c.isPrimaryKey())) { + if (columns.stream().noneMatch(c -> c.isPrimaryKey()) && !AbstractSchema.matchesViewName(tableName)) { // https://github.com/cwida/duckdb/issues/589 // https://github.com/cwida/duckdb/issues/588 // TODO: implement an option to enable/disable rowids diff --git a/src/sqlancer/duckdb/DuckDBToStringVisitor.java b/src/sqlancer/duckdb/DuckDBToStringVisitor.java index f7c97f3e5..b1a8eb703 100644 --- a/src/sqlancer/duckdb/DuckDBToStringVisitor.java +++ b/src/sqlancer/duckdb/DuckDBToStringVisitor.java @@ -1,7 +1,7 @@ package sqlancer.duckdb; import sqlancer.common.ast.newast.NewToStringVisitor; -import sqlancer.common.ast.newast.Node; +import sqlancer.common.ast.newast.TableReferenceNode; import sqlancer.duckdb.ast.DuckDBConstant; import sqlancer.duckdb.ast.DuckDBExpression; import sqlancer.duckdb.ast.DuckDBJoin; @@ -10,7 +10,7 @@ public class DuckDBToStringVisitor extends NewToStringVisitor { @Override - public void visitSpecific(Node expr) { + public void visitSpecific(DuckDBExpression expr) { if (expr instanceof DuckDBConstant) { visit((DuckDBConstant) expr); } else if (expr instanceof DuckDBSelect) { @@ -23,7 +23,7 @@ public void visitSpecific(Node expr) { } private void visit(DuckDBJoin join) { - visit(join.getLeftTable()); + visit((TableReferenceNode) join.getLeftTable()); sb.append(" "); sb.append(join.getJoinType()); sb.append(" "); @@ -31,7 +31,7 @@ private void visit(DuckDBJoin join) { sb.append(join.getOuterType()); } sb.append(" JOIN "); - visit(join.getRightTable()); + visit((TableReferenceNode) join.getRightTable()); if (join.getOnCondition() != null) { sb.append(" ON "); visit(join.getOnCondition()); @@ -68,9 +68,9 @@ private void visit(DuckDBSelect select) { sb.append(" HAVING "); visit(select.getHavingClause()); } - if (!select.getOrderByExpressions().isEmpty()) { + if (!select.getOrderByClauses().isEmpty()) { sb.append(" ORDER BY "); - visit(select.getOrderByExpressions()); + visit(select.getOrderByClauses()); } if (select.getLimitClause() != null) { sb.append(" LIMIT "); @@ -82,7 +82,7 @@ private void visit(DuckDBSelect select) { } } - public static String asString(Node expr) { + public static String asString(DuckDBExpression expr) { DuckDBToStringVisitor visitor = new DuckDBToStringVisitor(); visitor.visit(expr); return visitor.get(); diff --git a/src/sqlancer/duckdb/ast/DuckDBAlias.java b/src/sqlancer/duckdb/ast/DuckDBAlias.java new file mode 100644 index 000000000..cc1a91921 --- /dev/null +++ b/src/sqlancer/duckdb/ast/DuckDBAlias.java @@ -0,0 +1,9 @@ +package sqlancer.duckdb.ast; + +import sqlancer.common.ast.newast.NewAliasNode; + +public class DuckDBAlias extends NewAliasNode implements DuckDBExpression { + public DuckDBAlias(DuckDBExpression expr, String string) { + super(expr, string); + } +} diff --git a/src/sqlancer/duckdb/ast/DuckDBBetweenOperator.java b/src/sqlancer/duckdb/ast/DuckDBBetweenOperator.java new file mode 100644 index 000000000..eabaa6162 --- /dev/null +++ b/src/sqlancer/duckdb/ast/DuckDBBetweenOperator.java @@ -0,0 +1,10 @@ +package sqlancer.duckdb.ast; + +import sqlancer.common.ast.newast.NewBetweenOperatorNode; + +public class DuckDBBetweenOperator extends NewBetweenOperatorNode implements DuckDBExpression { + public DuckDBBetweenOperator(DuckDBExpression left, DuckDBExpression middle, DuckDBExpression right, + boolean isTrue) { + super(left, middle, right, isTrue); + } +} diff --git a/src/sqlancer/duckdb/ast/DuckDBBinaryOperator.java b/src/sqlancer/duckdb/ast/DuckDBBinaryOperator.java new file mode 100644 index 000000000..24492c145 --- /dev/null +++ b/src/sqlancer/duckdb/ast/DuckDBBinaryOperator.java @@ -0,0 +1,10 @@ +package sqlancer.duckdb.ast; + +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.common.ast.newast.NewBinaryOperatorNode; + +public class DuckDBBinaryOperator extends NewBinaryOperatorNode implements DuckDBExpression { + public DuckDBBinaryOperator(DuckDBExpression left, DuckDBExpression right, Operator op) { + super(left, right, op); + } +} diff --git a/src/sqlancer/duckdb/ast/DuckDBCaseOperator.java b/src/sqlancer/duckdb/ast/DuckDBCaseOperator.java new file mode 100644 index 000000000..ba1d9d96f --- /dev/null +++ b/src/sqlancer/duckdb/ast/DuckDBCaseOperator.java @@ -0,0 +1,12 @@ +package sqlancer.duckdb.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewCaseOperatorNode; + +public class DuckDBCaseOperator extends NewCaseOperatorNode implements DuckDBExpression { + public DuckDBCaseOperator(DuckDBExpression switchCondition, List conditions, + List expressions, DuckDBExpression elseExpr) { + super(switchCondition, conditions, expressions, elseExpr); + } +} diff --git a/src/sqlancer/duckdb/ast/DuckDBColumnReference.java b/src/sqlancer/duckdb/ast/DuckDBColumnReference.java new file mode 100644 index 000000000..d24382777 --- /dev/null +++ b/src/sqlancer/duckdb/ast/DuckDBColumnReference.java @@ -0,0 +1,12 @@ +package sqlancer.duckdb.ast; + +import sqlancer.common.ast.newast.ColumnReferenceNode; +import sqlancer.duckdb.DuckDBSchema; + +public class DuckDBColumnReference extends ColumnReferenceNode + implements DuckDBExpression { + public DuckDBColumnReference(DuckDBSchema.DuckDBColumn column) { + super(column); + } + +} diff --git a/src/sqlancer/duckdb/ast/DuckDBConstant.java b/src/sqlancer/duckdb/ast/DuckDBConstant.java index b895a18e5..f4af6918a 100644 --- a/src/sqlancer/duckdb/ast/DuckDBConstant.java +++ b/src/sqlancer/duckdb/ast/DuckDBConstant.java @@ -3,9 +3,7 @@ import java.sql.Timestamp; import java.text.SimpleDateFormat; -import sqlancer.common.ast.newast.Node; - -public class DuckDBConstant implements Node { +public class DuckDBConstant implements DuckDBExpression { private DuckDBConstant() { } @@ -161,31 +159,31 @@ public String toString() { } - public static Node createStringConstant(String text) { + public static DuckDBExpression createStringConstant(String text) { return new DuckDBTextConstant(text); } - public static Node createFloatConstant(double val) { + public static DuckDBExpression createFloatConstant(double val) { return new DuckDBDoubleConstant(val); } - public static Node createIntConstant(long val) { + public static DuckDBExpression createIntConstant(long val) { return new DuckDBIntConstant(val); } - public static Node createNullConstant() { + public static DuckDBExpression createNullConstant() { return new DuckDBNullConstant(); } - public static Node createBooleanConstant(boolean val) { + public static DuckDBExpression createBooleanConstant(boolean val) { return new DuckDBBooleanConstant(val); } - public static Node createDateConstant(long integer) { + public static DuckDBExpression createDateConstant(long integer) { return new DuckDBDateConstant(integer); } - public static Node createTimestampConstant(long integer) { + public static DuckDBExpression createTimestampConstant(long integer) { return new DuckDBTimestampConstant(integer); } diff --git a/src/sqlancer/duckdb/ast/DuckDBExpression.java b/src/sqlancer/duckdb/ast/DuckDBExpression.java index 3ff66ce33..21d4e6eb5 100644 --- a/src/sqlancer/duckdb/ast/DuckDBExpression.java +++ b/src/sqlancer/duckdb/ast/DuckDBExpression.java @@ -1,5 +1,7 @@ package sqlancer.duckdb.ast; -public interface DuckDBExpression { +import sqlancer.common.ast.newast.Expression; +import sqlancer.duckdb.DuckDBSchema.DuckDBColumn; +public interface DuckDBExpression extends Expression { } diff --git a/src/sqlancer/duckdb/ast/DuckDBFunction.java b/src/sqlancer/duckdb/ast/DuckDBFunction.java new file mode 100644 index 000000000..445583090 --- /dev/null +++ b/src/sqlancer/duckdb/ast/DuckDBFunction.java @@ -0,0 +1,11 @@ +package sqlancer.duckdb.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewFunctionNode; + +public class DuckDBFunction extends NewFunctionNode implements DuckDBExpression { + public DuckDBFunction(List args, F func) { + super(args, func); + } +} diff --git a/src/sqlancer/duckdb/ast/DuckDBInOperator.java b/src/sqlancer/duckdb/ast/DuckDBInOperator.java new file mode 100644 index 000000000..619601749 --- /dev/null +++ b/src/sqlancer/duckdb/ast/DuckDBInOperator.java @@ -0,0 +1,11 @@ +package sqlancer.duckdb.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewInOperatorNode; + +public class DuckDBInOperator extends NewInOperatorNode implements DuckDBExpression { + public DuckDBInOperator(DuckDBExpression left, List right, boolean isNegated) { + super(left, right, isNegated); + } +} diff --git a/src/sqlancer/duckdb/ast/DuckDBJoin.java b/src/sqlancer/duckdb/ast/DuckDBJoin.java index c1c0c94a9..f677e9efc 100644 --- a/src/sqlancer/duckdb/ast/DuckDBJoin.java +++ b/src/sqlancer/duckdb/ast/DuckDBJoin.java @@ -4,19 +4,18 @@ import java.util.List; import sqlancer.Randomly; -import sqlancer.common.ast.newast.Node; -import sqlancer.common.ast.newast.TableReferenceNode; +import sqlancer.common.ast.newast.Join; import sqlancer.duckdb.DuckDBProvider.DuckDBGlobalState; import sqlancer.duckdb.DuckDBSchema.DuckDBColumn; import sqlancer.duckdb.DuckDBSchema.DuckDBTable; import sqlancer.duckdb.gen.DuckDBExpressionGenerator; -public class DuckDBJoin implements Node { +public class DuckDBJoin implements DuckDBExpression, Join { - private final TableReferenceNode leftTable; - private final TableReferenceNode rightTable; + private final DuckDBTableReference leftTable; + private final DuckDBTableReference rightTable; private final JoinType joinType; - private final Node onCondition; + private DuckDBExpression onCondition; private OuterType outerType; public enum JoinType { @@ -35,20 +34,19 @@ public static OuterType getRandom() { } } - public DuckDBJoin(TableReferenceNode leftTable, - TableReferenceNode rightTable, JoinType joinType, - Node whereCondition) { + public DuckDBJoin(DuckDBTableReference leftTable, DuckDBTableReference rightTable, JoinType joinType, + DuckDBExpression whereCondition) { this.leftTable = leftTable; this.rightTable = rightTable; this.joinType = joinType; this.onCondition = whereCondition; } - public TableReferenceNode getLeftTable() { + public DuckDBTableReference getLeftTable() { return leftTable; } - public TableReferenceNode getRightTable() { + public DuckDBTableReference getRightTable() { return rightTable; } @@ -56,7 +54,7 @@ public JoinType getJoinType() { return joinType; } - public Node getOnCondition() { + public DuckDBExpression getOnCondition() { return onCondition; } @@ -68,12 +66,11 @@ public OuterType getOuterType() { return outerType; } - public static List> getJoins( - List> tableList, DuckDBGlobalState globalState) { - List> joinExpressions = new ArrayList<>(); + public static List getJoins(List tableList, DuckDBGlobalState globalState) { + List joinExpressions = new ArrayList<>(); while (tableList.size() >= 2 && Randomly.getBooleanWithRatherLowProbability()) { - TableReferenceNode leftTable = tableList.remove(0); - TableReferenceNode rightTable = tableList.remove(0); + DuckDBTableReference leftTable = tableList.remove(0); + DuckDBTableReference rightTable = tableList.remove(0); List columns = new ArrayList<>(leftTable.getTable().getColumns()); columns.addAll(rightTable.getTable().getColumns()); DuckDBExpressionGenerator joinGen = new DuckDBExpressionGenerator(globalState).setColumns(columns); @@ -99,26 +96,30 @@ public static List> getJoins( return joinExpressions; } - public static DuckDBJoin createRightOuterJoin(TableReferenceNode left, - TableReferenceNode right, Node predicate) { + public static DuckDBJoin createRightOuterJoin(DuckDBTableReference left, DuckDBTableReference right, + DuckDBExpression predicate) { return new DuckDBJoin(left, right, JoinType.RIGHT, predicate); } - public static DuckDBJoin createLeftOuterJoin(TableReferenceNode left, - TableReferenceNode right, Node predicate) { + public static DuckDBJoin createLeftOuterJoin(DuckDBTableReference left, DuckDBTableReference right, + DuckDBExpression predicate) { return new DuckDBJoin(left, right, JoinType.LEFT, predicate); } - public static DuckDBJoin createInnerJoin(TableReferenceNode left, - TableReferenceNode right, Node predicate) { + public static DuckDBJoin createInnerJoin(DuckDBTableReference left, DuckDBTableReference right, + DuckDBExpression predicate) { return new DuckDBJoin(left, right, JoinType.INNER, predicate); } - public static Node createNaturalJoin(TableReferenceNode left, - TableReferenceNode right, OuterType naturalJoinType) { + public static DuckDBJoin createNaturalJoin(DuckDBTableReference left, DuckDBTableReference right, + OuterType naturalJoinType) { DuckDBJoin join = new DuckDBJoin(left, right, JoinType.NATURAL, null); join.setOuterType(naturalJoinType); return join; } + @Override + public void setOnClause(DuckDBExpression onClause) { + this.onCondition = onClause; + } } diff --git a/src/sqlancer/duckdb/ast/DuckDBOrderingTerm.java b/src/sqlancer/duckdb/ast/DuckDBOrderingTerm.java new file mode 100644 index 000000000..8d7177655 --- /dev/null +++ b/src/sqlancer/duckdb/ast/DuckDBOrderingTerm.java @@ -0,0 +1,9 @@ +package sqlancer.duckdb.ast; + +import sqlancer.common.ast.newast.NewOrderingTerm; + +public class DuckDBOrderingTerm extends NewOrderingTerm implements DuckDBExpression { + public DuckDBOrderingTerm(DuckDBExpression expr, Ordering ordering) { + super(expr, ordering); + } +} diff --git a/src/sqlancer/duckdb/ast/DuckDBPostFixText.java b/src/sqlancer/duckdb/ast/DuckDBPostFixText.java new file mode 100644 index 000000000..a877a4c92 --- /dev/null +++ b/src/sqlancer/duckdb/ast/DuckDBPostFixText.java @@ -0,0 +1,9 @@ +package sqlancer.duckdb.ast; + +import sqlancer.common.ast.newast.NewPostfixTextNode; + +public class DuckDBPostFixText extends NewPostfixTextNode implements DuckDBExpression { + public DuckDBPostFixText(DuckDBExpression expr, String string) { + super(expr, string); + } +} diff --git a/src/sqlancer/duckdb/ast/DuckDBSelect.java b/src/sqlancer/duckdb/ast/DuckDBSelect.java index 9a2391aa3..e18e57a4d 100644 --- a/src/sqlancer/duckdb/ast/DuckDBSelect.java +++ b/src/sqlancer/duckdb/ast/DuckDBSelect.java @@ -1,9 +1,16 @@ package sqlancer.duckdb.ast; +import java.util.List; +import java.util.stream.Collectors; + import sqlancer.common.ast.SelectBase; -import sqlancer.common.ast.newast.Node; +import sqlancer.common.ast.newast.Select; +import sqlancer.duckdb.DuckDBSchema.DuckDBColumn; +import sqlancer.duckdb.DuckDBSchema.DuckDBTable; +import sqlancer.duckdb.DuckDBToStringVisitor; -public class DuckDBSelect extends SelectBase> implements Node { +public class DuckDBSelect extends SelectBase + implements Select, DuckDBExpression { private boolean isDistinct; @@ -15,4 +22,20 @@ public boolean isDistinct() { return isDistinct; } + @Override + public void setJoinClauses(List joinStatements) { + List expressions = joinStatements.stream().map(e -> (DuckDBExpression) e) + .collect(Collectors.toList()); + setJoinList(expressions); + } + + @Override + public List getJoinClauses() { + return getJoinList().stream().map(e -> (DuckDBJoin) e).collect(Collectors.toList()); + } + + @Override + public String asString() { + return DuckDBToStringVisitor.asString(this); + } } diff --git a/src/sqlancer/duckdb/ast/DuckDBTableReference.java b/src/sqlancer/duckdb/ast/DuckDBTableReference.java new file mode 100644 index 000000000..0a8d795c8 --- /dev/null +++ b/src/sqlancer/duckdb/ast/DuckDBTableReference.java @@ -0,0 +1,11 @@ +package sqlancer.duckdb.ast; + +import sqlancer.common.ast.newast.TableReferenceNode; +import sqlancer.duckdb.DuckDBSchema; + +public class DuckDBTableReference extends TableReferenceNode + implements DuckDBExpression { + public DuckDBTableReference(DuckDBSchema.DuckDBTable table) { + super(table); + } +} diff --git a/src/sqlancer/duckdb/ast/DuckDBTernary.java b/src/sqlancer/duckdb/ast/DuckDBTernary.java new file mode 100644 index 000000000..921f77b90 --- /dev/null +++ b/src/sqlancer/duckdb/ast/DuckDBTernary.java @@ -0,0 +1,10 @@ +package sqlancer.duckdb.ast; + +import sqlancer.common.ast.newast.NewTernaryNode; + +public class DuckDBTernary extends NewTernaryNode implements DuckDBExpression { + public DuckDBTernary(DuckDBExpression left, DuckDBExpression middle, DuckDBExpression right, String leftString, + String rightString) { + super(left, middle, right, leftString, rightString); + } +} diff --git a/src/sqlancer/duckdb/ast/DuckDBUnaryPostfixOperator.java b/src/sqlancer/duckdb/ast/DuckDBUnaryPostfixOperator.java new file mode 100644 index 000000000..856c129ec --- /dev/null +++ b/src/sqlancer/duckdb/ast/DuckDBUnaryPostfixOperator.java @@ -0,0 +1,11 @@ +package sqlancer.duckdb.ast; + +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.newast.NewUnaryPostfixOperatorNode; + +public class DuckDBUnaryPostfixOperator extends NewUnaryPostfixOperatorNode + implements DuckDBExpression { + public DuckDBUnaryPostfixOperator(DuckDBExpression expr, BinaryOperatorNode.Operator op) { + super(expr, op); + } +} diff --git a/src/sqlancer/duckdb/ast/DuckDBUnaryPrefixOperator.java b/src/sqlancer/duckdb/ast/DuckDBUnaryPrefixOperator.java new file mode 100644 index 000000000..95c8b4990 --- /dev/null +++ b/src/sqlancer/duckdb/ast/DuckDBUnaryPrefixOperator.java @@ -0,0 +1,11 @@ +package sqlancer.duckdb.ast; + +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode; + +public class DuckDBUnaryPrefixOperator extends NewUnaryPrefixOperatorNode + implements DuckDBExpression { + public DuckDBUnaryPrefixOperator(DuckDBExpression expr, BinaryOperatorNode.Operator operator) { + super(expr, operator); + } +} diff --git a/src/sqlancer/duckdb/gen/DuckDBExpressionGenerator.java b/src/sqlancer/duckdb/gen/DuckDBExpressionGenerator.java index eeba94027..278685a46 100644 --- a/src/sqlancer/duckdb/gen/DuckDBExpressionGenerator.java +++ b/src/sqlancer/duckdb/gen/DuckDBExpressionGenerator.java @@ -3,33 +3,44 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.stream.Collectors; import sqlancer.IgnoreMeException; import sqlancer.Randomly; import sqlancer.common.ast.BinaryOperatorNode.Operator; -import sqlancer.common.ast.newast.ColumnReferenceNode; -import sqlancer.common.ast.newast.NewBetweenOperatorNode; -import sqlancer.common.ast.newast.NewBinaryOperatorNode; -import sqlancer.common.ast.newast.NewCaseOperatorNode; -import sqlancer.common.ast.newast.NewFunctionNode; -import sqlancer.common.ast.newast.NewInOperatorNode; -import sqlancer.common.ast.newast.NewOrderingTerm; import sqlancer.common.ast.newast.NewOrderingTerm.Ordering; -import sqlancer.common.ast.newast.NewTernaryNode; import sqlancer.common.ast.newast.NewUnaryPostfixOperatorNode; -import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode; -import sqlancer.common.ast.newast.Node; +import sqlancer.common.gen.NoRECGenerator; +import sqlancer.common.gen.TLPWhereGenerator; import sqlancer.common.gen.UntypedExpressionGenerator; +import sqlancer.common.schema.AbstractTables; import sqlancer.duckdb.DuckDBProvider.DuckDBGlobalState; import sqlancer.duckdb.DuckDBSchema.DuckDBColumn; import sqlancer.duckdb.DuckDBSchema.DuckDBCompositeDataType; import sqlancer.duckdb.DuckDBSchema.DuckDBDataType; +import sqlancer.duckdb.DuckDBSchema.DuckDBTable; +import sqlancer.duckdb.DuckDBToStringVisitor; +import sqlancer.duckdb.ast.DuckDBBetweenOperator; +import sqlancer.duckdb.ast.DuckDBBinaryOperator; +import sqlancer.duckdb.ast.DuckDBCaseOperator; +import sqlancer.duckdb.ast.DuckDBColumnReference; import sqlancer.duckdb.ast.DuckDBConstant; import sqlancer.duckdb.ast.DuckDBExpression; - -public final class DuckDBExpressionGenerator extends UntypedExpressionGenerator, DuckDBColumn> { +import sqlancer.duckdb.ast.DuckDBFunction; +import sqlancer.duckdb.ast.DuckDBInOperator; +import sqlancer.duckdb.ast.DuckDBJoin; +import sqlancer.duckdb.ast.DuckDBOrderingTerm; +import sqlancer.duckdb.ast.DuckDBPostFixText; +import sqlancer.duckdb.ast.DuckDBSelect; +import sqlancer.duckdb.ast.DuckDBTableReference; +import sqlancer.duckdb.ast.DuckDBTernary; + +public final class DuckDBExpressionGenerator extends UntypedExpressionGenerator + implements NoRECGenerator, + TLPWhereGenerator { private final DuckDBGlobalState globalState; + private List tables; public DuckDBExpressionGenerator(DuckDBGlobalState globalState) { this.globalState = globalState; @@ -41,14 +52,14 @@ private enum Expression { } @Override - protected Node generateExpression(int depth) { + protected DuckDBExpression generateExpression(int depth) { if (depth >= globalState.getOptions().getMaxExpressionDepth() || Randomly.getBoolean()) { return generateLeafNode(); } if (allowAggregates && Randomly.getBoolean()) { DuckDBAggregateFunction aggregate = DuckDBAggregateFunction.getRandom(); allowAggregates = false; - return new NewFunctionNode<>(generateExpressions(aggregate.getNrArgs(), depth + 1), aggregate); + return new DuckDBFunction<>(generateExpressions(aggregate.getNrArgs(), depth + 1), aggregate); } List possibleOptions = new ArrayList<>(Arrays.asList(Expression.values())); if (!globalState.getDbmsSpecificOptions().testCollate) { @@ -78,44 +89,41 @@ protected Node generateExpression(int depth) { Expression expr = Randomly.fromList(possibleOptions); switch (expr) { case COLLATE: - return new NewUnaryPostfixOperatorNode(generateExpression(depth + 1), + return new sqlancer.duckdb.ast.DuckDBUnaryPostfixOperator(generateExpression(depth + 1), DuckDBCollate.getRandom()); case UNARY_PREFIX: - return new NewUnaryPrefixOperatorNode(generateExpression(depth + 1), + return new sqlancer.duckdb.ast.DuckDBUnaryPrefixOperator(generateExpression(depth + 1), DuckDBUnaryPrefixOperator.getRandom()); case UNARY_POSTFIX: - return new NewUnaryPostfixOperatorNode(generateExpression(depth + 1), + return new sqlancer.duckdb.ast.DuckDBUnaryPostfixOperator(generateExpression(depth + 1), DuckDBUnaryPostfixOperator.getRandom()); case BINARY_COMPARISON: Operator op = DuckDBBinaryComparisonOperator.getRandom(); - return new NewBinaryOperatorNode(generateExpression(depth + 1), - generateExpression(depth + 1), op); + return new DuckDBBinaryOperator(generateExpression(depth + 1), generateExpression(depth + 1), op); case BINARY_LOGICAL: op = DuckDBBinaryLogicalOperator.getRandom(); - return new NewBinaryOperatorNode(generateExpression(depth + 1), - generateExpression(depth + 1), op); + return new DuckDBBinaryOperator(generateExpression(depth + 1), generateExpression(depth + 1), op); case BINARY_ARITHMETIC: - return new NewBinaryOperatorNode(generateExpression(depth + 1), - generateExpression(depth + 1), DuckDBBinaryArithmeticOperator.getRandom()); + return new DuckDBBinaryOperator(generateExpression(depth + 1), generateExpression(depth + 1), + DuckDBBinaryArithmeticOperator.getRandom()); case CAST: return new DuckDBCastOperation(generateExpression(depth + 1), DuckDBCompositeDataType.getRandomWithoutNull()); case FUNC: DBFunction func = DBFunction.getRandom(); - return new NewFunctionNode(generateExpressions(func.getNrArgs()), func); + return new DuckDBFunction<>(generateExpressions(func.getNrArgs()), func); case BETWEEN: - return new NewBetweenOperatorNode(generateExpression(depth + 1), - generateExpression(depth + 1), generateExpression(depth + 1), Randomly.getBoolean()); + return new DuckDBBetweenOperator(generateExpression(depth + 1), generateExpression(depth + 1), + generateExpression(depth + 1), Randomly.getBoolean()); case IN: - return new NewInOperatorNode(generateExpression(depth + 1), + return new DuckDBInOperator(generateExpression(depth + 1), generateExpressions(Randomly.smallNumber() + 1, depth + 1), Randomly.getBoolean()); case CASE: int nr = Randomly.smallNumber() + 1; - return new NewCaseOperatorNode(generateExpression(depth + 1), - generateExpressions(nr, depth + 1), generateExpressions(nr, depth + 1), - generateExpression(depth + 1)); + return new DuckDBCaseOperator(generateExpression(depth + 1), generateExpressions(nr, depth + 1), + generateExpressions(nr, depth + 1), generateExpression(depth + 1)); case LIKE_ESCAPE: - return new NewTernaryNode(generateExpression(depth + 1), generateExpression(depth + 1), + return new DuckDBTernary(generateExpression(depth + 1), generateExpression(depth + 1), generateExpression(depth + 1), "LIKE", "ESCAPE"); default: throw new AssertionError(); @@ -123,13 +131,13 @@ protected Node generateExpression(int depth) { } @Override - protected Node generateColumn() { + protected DuckDBExpression generateColumn() { DuckDBColumn column = Randomly.fromList(columns); - return new ColumnReferenceNode(column); + return new DuckDBColumnReference(column); } @Override - public Node generateConstant() { + public DuckDBExpression generateConstant() { if (Randomly.getBooleanWithSmallProbability()) { return DuckDBConstant.createNullConstant(); } @@ -171,21 +179,22 @@ public Node generateConstant() { } @Override - public List> generateOrderBys() { - List> expr = super.generateOrderBys(); - List> newExpr = new ArrayList<>(expr.size()); - for (Node curExpr : expr) { + public List generateOrderBys() { + List expr = super.generateOrderBys(); + List newExpr = new ArrayList<>(expr.size()); + for (DuckDBExpression curExpr : expr) { if (Randomly.getBoolean()) { - curExpr = new NewOrderingTerm<>(curExpr, Ordering.getRandom()); + curExpr = new DuckDBOrderingTerm(curExpr, Ordering.getRandom()); } newExpr.add(curExpr); } return newExpr; }; - public static class DuckDBCastOperation extends NewUnaryPostfixOperatorNode { + public static class DuckDBCastOperation extends NewUnaryPostfixOperatorNode + implements DuckDBExpression { - public DuckDBCastOperation(Node expr, DuckDBCompositeDataType type) { + public DuckDBCastOperation(DuckDBExpression expr, DuckDBCompositeDataType type) { super(expr, new Operator() { @Override @@ -422,25 +431,96 @@ public String getTextRepresentation() { } - public NewFunctionNode generateArgsForAggregate( - DuckDBAggregateFunction aggregateFunction) { - return new NewFunctionNode( - generateExpressions(aggregateFunction.getNrArgs()), aggregateFunction); + public DuckDBFunction generateArgsForAggregate(DuckDBAggregateFunction aggregateFunction) { + return new DuckDBFunction<>(generateExpressions(aggregateFunction.getNrArgs()), aggregateFunction); } - public Node generateAggregate() { + public DuckDBExpression generateAggregate() { DuckDBAggregateFunction aggrFunc = DuckDBAggregateFunction.getRandom(); return generateArgsForAggregate(aggrFunc); } @Override - public Node negatePredicate(Node predicate) { - return new NewUnaryPrefixOperatorNode<>(predicate, DuckDBUnaryPrefixOperator.NOT); + public DuckDBExpression negatePredicate(DuckDBExpression predicate) { + return new sqlancer.duckdb.ast.DuckDBUnaryPrefixOperator(predicate, DuckDBUnaryPrefixOperator.NOT); + } + + @Override + public DuckDBExpression isNull(DuckDBExpression expr) { + return new sqlancer.duckdb.ast.DuckDBUnaryPostfixOperator(expr, DuckDBUnaryPostfixOperator.IS_NULL); + } + + @Override + public DuckDBExpressionGenerator setTablesAndColumns(AbstractTables tables) { + this.columns = tables.getColumns(); + this.tables = tables.getTables(); + + return this; } @Override - public Node isNull(Node expr) { - return new NewUnaryPostfixOperatorNode<>(expr, DuckDBUnaryPostfixOperator.IS_NULL); + public DuckDBExpression generateBooleanExpression() { + return generateExpression(); } + @Override + public DuckDBSelect generateSelect() { + return new DuckDBSelect(); + } + + @Override + public List getRandomJoinClauses() { + List tableList = tables.stream().map(t -> new DuckDBTableReference(t)) + .collect(Collectors.toList()); + List joins = DuckDBJoin.getJoins(tableList, globalState); + tables = tableList.stream().map(t -> t.getTable()).collect(Collectors.toList()); + return joins; + } + + @Override + public List getTableRefs() { + return tables.stream().map(t -> new DuckDBTableReference(t)).collect(Collectors.toList()); + } + + @Override + public String generateOptimizedQueryString(DuckDBSelect select, DuckDBExpression whereCondition, + boolean shouldUseAggregate) { + List allColumns = columns.stream().map((c) -> new DuckDBColumnReference(c)) + .collect(Collectors.toList()); + if (shouldUseAggregate) { + DuckDBFunction aggr = new DuckDBFunction<>( + Arrays.asList(new DuckDBColumnReference( + new DuckDBColumn("*", new DuckDBCompositeDataType(DuckDBDataType.INT, 0), false, false))), + DuckDBAggregateFunction.COUNT); + select.setFetchColumns(Arrays.asList(aggr)); + } else { + select.setFetchColumns(allColumns); + if (Randomly.getBooleanWithSmallProbability()) { + select.setOrderByClauses(generateOrderBys()); + } + } + select.setWhereClause(whereCondition); + + return select.asString(); + } + + @Override + public String generateUnoptimizedQueryString(DuckDBSelect select, DuckDBExpression whereCondition) { + DuckDBExpression asText = new DuckDBPostFixText(new DuckDBCastOperation( + new DuckDBPostFixText(whereCondition, + " IS NOT NULL AND " + DuckDBToStringVisitor.asString(whereCondition)), + new DuckDBCompositeDataType(DuckDBDataType.INT, 8)), "as count"); + select.setFetchColumns(Arrays.asList(asText)); + + return "SELECT SUM(count) FROM (" + select.asString() + ") as res"; + } + + @Override + public List generateFetchColumns(boolean shouldCreateDummy) { + if (Randomly.getBoolean()) { + return List.of(new DuckDBColumnReference(new DuckDBColumn("*", null, false, false))); + } + return Randomly.nonEmptySubset(columns).stream().map(c -> new DuckDBColumnReference(c)) + .collect(Collectors.toList()); + } } diff --git a/src/sqlancer/duckdb/gen/DuckDBIndexGenerator.java b/src/sqlancer/duckdb/gen/DuckDBIndexGenerator.java index eec5384aa..6c50b204d 100644 --- a/src/sqlancer/duckdb/gen/DuckDBIndexGenerator.java +++ b/src/sqlancer/duckdb/gen/DuckDBIndexGenerator.java @@ -3,14 +3,11 @@ import java.util.List; import sqlancer.Randomly; -import sqlancer.common.ast.newast.Node; import sqlancer.common.query.ExpectedErrors; import sqlancer.common.query.SQLQueryAdapter; import sqlancer.duckdb.DuckDBProvider.DuckDBGlobalState; import sqlancer.duckdb.DuckDBSchema.DuckDBColumn; import sqlancer.duckdb.DuckDBSchema.DuckDBTable; -import sqlancer.duckdb.DuckDBToStringVisitor; -import sqlancer.duckdb.ast.DuckDBExpression; public final class DuckDBIndexGenerator { @@ -22,11 +19,11 @@ public static SQLQueryAdapter getQuery(DuckDBGlobalState globalState) { StringBuilder sb = new StringBuilder(); sb.append("CREATE "); if (Randomly.getBoolean()) { - errors.add("Cant create unique index, table contains duplicate data on indexed column(s)"); + errors.add("Data contains duplicates on indexed column(s)"); sb.append("UNIQUE "); } sb.append("INDEX "); - sb.append(Randomly.fromOptions("i0", "i1", "i2", "i3", "i4")); // cannot query this information + sb.append(globalState.getSchema().getFreeIndexName()); sb.append(" ON "); DuckDBTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); sb.append(table.getName()); @@ -43,15 +40,8 @@ public static SQLQueryAdapter getQuery(DuckDBGlobalState globalState) { } } sb.append(")"); - if (Randomly.getBoolean()) { - sb.append(" WHERE "); - Node expr = new DuckDBExpressionGenerator(globalState).setColumns(table.getColumns()) - .generateExpression(); - sb.append(DuckDBToStringVisitor.asString(expr)); - } - errors.add("already exists!"); if (globalState.getDbmsSpecificOptions().testRowid) { - errors.add("Cannot create an index on the rowid!"); + errors.add("cannot create an index on the rowid"); } return new SQLQueryAdapter(sb.toString(), errors, true); } diff --git a/src/sqlancer/duckdb/gen/DuckDBInsertGenerator.java b/src/sqlancer/duckdb/gen/DuckDBInsertGenerator.java index 1e5eb60b9..6793d2b51 100644 --- a/src/sqlancer/duckdb/gen/DuckDBInsertGenerator.java +++ b/src/sqlancer/duckdb/gen/DuckDBInsertGenerator.java @@ -29,7 +29,7 @@ public static SQLQueryAdapter getQuery(DuckDBGlobalState globalState) { private SQLQueryAdapter generate() { sb.append("INSERT INTO "); DuckDBTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); - List columns = table.getRandomNonEmptyColumnSubset(); + List columns = table.getRandomNonEmptyColumnSubsetFilter(p -> !p.getName().equals("rowid")); sb.append(table.getName()); sb.append("("); sb.append(columns.stream().map(c -> c.getName()).collect(Collectors.joining(", "))); @@ -41,7 +41,7 @@ private SQLQueryAdapter generate() { } @Override - protected void insertValue(DuckDBColumn tiDBColumn) { + protected void insertValue(DuckDBColumn columnDuckDB) { // TODO: select a more meaningful value if (Randomly.getBooleanWithRatherLowProbability()) { sb.append("DEFAULT"); diff --git a/src/sqlancer/duckdb/gen/DuckDBRandomQuerySynthesizer.java b/src/sqlancer/duckdb/gen/DuckDBRandomQuerySynthesizer.java index 29af42e10..d88d4f0b8 100644 --- a/src/sqlancer/duckdb/gen/DuckDBRandomQuerySynthesizer.java +++ b/src/sqlancer/duckdb/gen/DuckDBRandomQuerySynthesizer.java @@ -5,8 +5,6 @@ import java.util.stream.Collectors; import sqlancer.Randomly; -import sqlancer.common.ast.newast.Node; -import sqlancer.common.ast.newast.TableReferenceNode; import sqlancer.duckdb.DuckDBProvider.DuckDBGlobalState; import sqlancer.duckdb.DuckDBSchema.DuckDBTable; import sqlancer.duckdb.DuckDBSchema.DuckDBTables; @@ -14,6 +12,7 @@ import sqlancer.duckdb.ast.DuckDBExpression; import sqlancer.duckdb.ast.DuckDBJoin; import sqlancer.duckdb.ast.DuckDBSelect; +import sqlancer.duckdb.ast.DuckDBTableReference; public final class DuckDBRandomQuerySynthesizer { @@ -28,10 +27,10 @@ public static DuckDBSelect generateSelect(DuckDBGlobalState globalState, int nrC // TODO: distinct // select.setDistinct(Randomly.getBoolean()); // boolean allowAggregates = Randomly.getBooleanWithSmallProbability(); - List> columns = new ArrayList<>(); + List columns = new ArrayList<>(); for (int i = 0; i < nrColumns; i++) { // if (allowAggregates && Randomly.getBoolean()) { - Node expression = gen.generateExpression(); + DuckDBExpression expression = gen.generateExpression(); columns.add(expression); // } else { // columns.add(gen()); @@ -39,16 +38,16 @@ public static DuckDBSelect generateSelect(DuckDBGlobalState globalState, int nrC } select.setFetchColumns(columns); List tables = targetTables.getTables(); - List> tableList = tables.stream() - .map(t -> new TableReferenceNode(t)).collect(Collectors.toList()); - List> joins = DuckDBJoin.getJoins(tableList, globalState); + List tableList = tables.stream().map(t -> new DuckDBTableReference(t)) + .collect(Collectors.toList()); + List joins = DuckDBJoin.getJoins(tableList, globalState); select.setJoinList(joins.stream().collect(Collectors.toList())); select.setFromList(tableList.stream().collect(Collectors.toList())); if (Randomly.getBoolean()) { select.setWhereClause(gen.generateExpression()); } if (Randomly.getBoolean()) { - select.setOrderByExpressions(gen.generateOrderBys()); + select.setOrderByClauses(gen.generateOrderBys()); } if (Randomly.getBoolean()) { select.setGroupByExpressions(gen.generateExpressions(Randomly.smallNumber() + 1)); diff --git a/src/sqlancer/duckdb/gen/DuckDBTableGenerator.java b/src/sqlancer/duckdb/gen/DuckDBTableGenerator.java index d3cd27fde..ea6d3537f 100644 --- a/src/sqlancer/duckdb/gen/DuckDBTableGenerator.java +++ b/src/sqlancer/duckdb/gen/DuckDBTableGenerator.java @@ -5,7 +5,6 @@ import java.util.stream.Collectors; import sqlancer.Randomly; -import sqlancer.common.ast.newast.Node; import sqlancer.common.gen.UntypedExpressionGenerator; import sqlancer.common.query.ExpectedErrors; import sqlancer.common.query.SQLQueryAdapter; @@ -27,8 +26,8 @@ public SQLQueryAdapter getQuery(DuckDBGlobalState globalState) { sb.append(tableName); sb.append("("); List columns = getNewColumns(); - UntypedExpressionGenerator, DuckDBColumn> gen = new DuckDBExpressionGenerator( - globalState).setColumns(columns); + UntypedExpressionGenerator gen = new DuckDBExpressionGenerator(globalState) + .setColumns(columns); for (int i = 0; i < columns.size(); i++) { if (i != 0) { sb.append(", "); diff --git a/src/sqlancer/duckdb/gen/DuckDBUpdateGenerator.java b/src/sqlancer/duckdb/gen/DuckDBUpdateGenerator.java index ba8c867fd..b4ffd0140 100644 --- a/src/sqlancer/duckdb/gen/DuckDBUpdateGenerator.java +++ b/src/sqlancer/duckdb/gen/DuckDBUpdateGenerator.java @@ -3,8 +3,7 @@ import java.util.List; import sqlancer.Randomly; -import sqlancer.common.ast.newast.Node; -import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.gen.AbstractUpdateGenerator; import sqlancer.common.query.SQLQueryAdapter; import sqlancer.duckdb.DuckDBErrors; import sqlancer.duckdb.DuckDBProvider.DuckDBGlobalState; @@ -13,36 +12,41 @@ import sqlancer.duckdb.DuckDBToStringVisitor; import sqlancer.duckdb.ast.DuckDBExpression; -public final class DuckDBUpdateGenerator { +public final class DuckDBUpdateGenerator extends AbstractUpdateGenerator { - private DuckDBUpdateGenerator() { + private final DuckDBGlobalState globalState; + private DuckDBExpressionGenerator gen; + + private DuckDBUpdateGenerator(DuckDBGlobalState globalState) { + this.globalState = globalState; } public static SQLQueryAdapter getQuery(DuckDBGlobalState globalState) { - StringBuilder sb = new StringBuilder("UPDATE "); - ExpectedErrors errors = new ExpectedErrors(); + return new DuckDBUpdateGenerator(globalState).generate(); + } + + private SQLQueryAdapter generate() { DuckDBTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + List columns = table.getRandomNonEmptyColumnSubsetFilter(p -> !p.getName().equals("rowid")); + gen = new DuckDBExpressionGenerator(globalState).setColumns(table.getColumns()); + sb.append("UPDATE "); sb.append(table.getName()); - DuckDBExpressionGenerator gen = new DuckDBExpressionGenerator(globalState).setColumns(table.getColumns()); sb.append(" SET "); - List columns = table.getRandomNonEmptyColumnSubset(); - for (int i = 0; i < columns.size(); i++) { - if (i != 0) { - sb.append(", "); - } - sb.append(columns.get(i).getName()); - sb.append("="); - Node expr; - if (Randomly.getBooleanWithSmallProbability()) { - expr = gen.generateExpression(); - DuckDBErrors.addExpressionErrors(errors); - } else { - expr = gen.generateConstant(); - } - sb.append(DuckDBToStringVisitor.asString(expr)); - } + updateColumns(columns); DuckDBErrors.addInsertErrors(errors); return new SQLQueryAdapter(sb.toString(), errors); } + @Override + protected void updateValue(DuckDBColumn column) { + DuckDBExpression expr; + if (Randomly.getBooleanWithSmallProbability()) { + expr = gen.generateExpression(); + DuckDBErrors.addExpressionErrors(errors); + } else { + expr = gen.generateConstant(); + } + sb.append(DuckDBToStringVisitor.asString(expr)); + } + } diff --git a/src/sqlancer/duckdb/test/DuckDBNoRECOracle.java b/src/sqlancer/duckdb/test/DuckDBNoRECOracle.java deleted file mode 100644 index 550043615..000000000 --- a/src/sqlancer/duckdb/test/DuckDBNoRECOracle.java +++ /dev/null @@ -1,137 +0,0 @@ -package sqlancer.duckdb.test; - -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Statement; -import java.util.Arrays; -import java.util.List; -import java.util.stream.Collectors; - -import sqlancer.IgnoreMeException; -import sqlancer.Randomly; -import sqlancer.SQLConnection; -import sqlancer.common.ast.newast.ColumnReferenceNode; -import sqlancer.common.ast.newast.NewPostfixTextNode; -import sqlancer.common.ast.newast.Node; -import sqlancer.common.ast.newast.TableReferenceNode; -import sqlancer.common.oracle.NoRECBase; -import sqlancer.common.oracle.TestOracle; -import sqlancer.common.query.SQLQueryAdapter; -import sqlancer.common.query.SQLancerResultSet; -import sqlancer.duckdb.DuckDBErrors; -import sqlancer.duckdb.DuckDBProvider.DuckDBGlobalState; -import sqlancer.duckdb.DuckDBSchema; -import sqlancer.duckdb.DuckDBSchema.DuckDBColumn; -import sqlancer.duckdb.DuckDBSchema.DuckDBCompositeDataType; -import sqlancer.duckdb.DuckDBSchema.DuckDBDataType; -import sqlancer.duckdb.DuckDBSchema.DuckDBTable; -import sqlancer.duckdb.DuckDBSchema.DuckDBTables; -import sqlancer.duckdb.DuckDBToStringVisitor; -import sqlancer.duckdb.ast.DuckDBExpression; -import sqlancer.duckdb.ast.DuckDBJoin; -import sqlancer.duckdb.ast.DuckDBSelect; -import sqlancer.duckdb.gen.DuckDBExpressionGenerator; -import sqlancer.duckdb.gen.DuckDBExpressionGenerator.DuckDBCastOperation; - -public class DuckDBNoRECOracle extends NoRECBase implements TestOracle { - - private final DuckDBSchema s; - - public DuckDBNoRECOracle(DuckDBGlobalState globalState) { - super(globalState); - this.s = globalState.getSchema(); - DuckDBErrors.addExpressionErrors(errors); - } - - @Override - public void check() throws SQLException { - DuckDBTables randomTables = s.getRandomTableNonEmptyTables(); - List columns = randomTables.getColumns(); - DuckDBExpressionGenerator gen = new DuckDBExpressionGenerator(state).setColumns(columns); - Node randomWhereCondition = gen.generateExpression(); - List tables = randomTables.getTables(); - List> tableList = tables.stream() - .map(t -> new TableReferenceNode(t)).collect(Collectors.toList()); - List> joins = DuckDBJoin.getJoins(tableList, state); - int secondCount = getSecondQuery(tableList.stream().collect(Collectors.toList()), randomWhereCondition, joins); - int firstCount = getFirstQueryCount(con, tableList.stream().collect(Collectors.toList()), columns, - randomWhereCondition, joins); - if (firstCount == -1 || secondCount == -1) { - throw new IgnoreMeException(); - } - if (firstCount != secondCount) { - throw new AssertionError( - optimizedQueryString + "; -- " + firstCount + "\n" + unoptimizedQueryString + " -- " + secondCount); - } - } - - private int getSecondQuery(List> tableList, Node randomWhereCondition, - List> joins) throws SQLException { - DuckDBSelect select = new DuckDBSelect(); - // select.setGroupByClause(groupBys); - // DuckDBExpression isTrue = DuckDBPostfixOperation.create(randomWhereCondition, - // PostfixOperator.IS_TRUE); - Node asText = new NewPostfixTextNode<>(new DuckDBCastOperation( - new NewPostfixTextNode(randomWhereCondition, - " IS NOT NULL AND " + DuckDBToStringVisitor.asString(randomWhereCondition)), - new DuckDBCompositeDataType(DuckDBDataType.INT, 8)), "as count"); - select.setFetchColumns(Arrays.asList(asText)); - select.setFromList(tableList); - // select.setSelectType(SelectType.ALL); - select.setJoinList(joins); - int secondCount = 0; - unoptimizedQueryString = "SELECT SUM(count) FROM (" + DuckDBToStringVisitor.asString(select) + ") as res"; - errors.add("canceling statement due to statement timeout"); - SQLQueryAdapter q = new SQLQueryAdapter(unoptimizedQueryString, errors); - SQLancerResultSet rs; - try { - rs = q.executeAndGetLogged(state); - } catch (Exception e) { - throw new AssertionError(unoptimizedQueryString, e); - } - if (rs == null) { - return -1; - } - if (rs.next()) { - secondCount += rs.getLong(1); - } - rs.close(); - return secondCount; - } - - private int getFirstQueryCount(SQLConnection con, List> tableList, - List columns, Node randomWhereCondition, List> joins) - throws SQLException { - DuckDBSelect select = new DuckDBSelect(); - // select.setGroupByClause(groupBys); - // DuckDBAggregate aggr = new DuckDBAggregate( - List> allColumns = columns.stream() - .map((c) -> new ColumnReferenceNode(c)).collect(Collectors.toList()); - // DuckDBAggregateFunction.COUNT); - // select.setFetchColumns(Arrays.asList(aggr)); - select.setFetchColumns(allColumns); - select.setFromList(tableList); - select.setWhereClause(randomWhereCondition); - if (Randomly.getBooleanWithSmallProbability()) { - select.setOrderByExpressions(new DuckDBExpressionGenerator(state).setColumns(columns).generateOrderBys()); - } - // select.setSelectType(SelectType.ALL); - select.setJoinList(joins); - int firstCount = 0; - try (Statement stat = con.createStatement()) { - optimizedQueryString = DuckDBToStringVisitor.asString(select); - if (options.logEachSelect()) { - logger.writeCurrent(optimizedQueryString); - } - try (ResultSet rs = stat.executeQuery(optimizedQueryString)) { - while (rs.next()) { - firstCount++; - } - } - } catch (SQLException e) { - throw new IgnoreMeException(); - } - return firstCount; - } - -} diff --git a/src/sqlancer/duckdb/test/DuckDBQueryPartitioningAggregateTester.java b/src/sqlancer/duckdb/test/DuckDBQueryPartitioningAggregateTester.java index a6b8d8f27..e82f97db6 100644 --- a/src/sqlancer/duckdb/test/DuckDBQueryPartitioningAggregateTester.java +++ b/src/sqlancer/duckdb/test/DuckDBQueryPartitioningAggregateTester.java @@ -8,12 +8,7 @@ import sqlancer.ComparatorHelper; import sqlancer.IgnoreMeException; import sqlancer.Randomly; -import sqlancer.common.ast.newast.NewAliasNode; -import sqlancer.common.ast.newast.NewBinaryOperatorNode; import sqlancer.common.ast.newast.NewFunctionNode; -import sqlancer.common.ast.newast.NewUnaryPostfixOperatorNode; -import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode; -import sqlancer.common.ast.newast.Node; import sqlancer.common.oracle.TestOracle; import sqlancer.common.query.SQLQueryAdapter; import sqlancer.common.query.SQLancerResultSet; @@ -22,7 +17,10 @@ import sqlancer.duckdb.DuckDBSchema.DuckDBCompositeDataType; import sqlancer.duckdb.DuckDBSchema.DuckDBDataType; import sqlancer.duckdb.DuckDBToStringVisitor; +import sqlancer.duckdb.ast.DuckDBAlias; +import sqlancer.duckdb.ast.DuckDBBinaryOperator; import sqlancer.duckdb.ast.DuckDBExpression; +import sqlancer.duckdb.ast.DuckDBFunction; import sqlancer.duckdb.ast.DuckDBSelect; import sqlancer.duckdb.gen.DuckDBExpressionGenerator.DuckDBAggregateFunction; import sqlancer.duckdb.gen.DuckDBExpressionGenerator.DuckDBBinaryArithmeticOperator; @@ -30,7 +28,8 @@ import sqlancer.duckdb.gen.DuckDBExpressionGenerator.DuckDBUnaryPostfixOperator; import sqlancer.duckdb.gen.DuckDBExpressionGenerator.DuckDBUnaryPrefixOperator; -public class DuckDBQueryPartitioningAggregateTester extends DuckDBQueryPartitioningBase implements TestOracle { +public class DuckDBQueryPartitioningAggregateTester extends DuckDBQueryPartitioningBase + implements TestOracle { private String firstResult; private String secondResult; @@ -48,16 +47,15 @@ public void check() throws SQLException { DuckDBAggregateFunction aggregateFunction = Randomly.fromOptions(DuckDBAggregateFunction.MAX, DuckDBAggregateFunction.MIN, DuckDBAggregateFunction.SUM, DuckDBAggregateFunction.COUNT, DuckDBAggregateFunction.AVG/* , DuckDBAggregateFunction.STDDEV_POP */); - NewFunctionNode aggregate = gen - .generateArgsForAggregate(aggregateFunction); - List> fetchColumns = new ArrayList<>(); + DuckDBFunction aggregate = gen.generateArgsForAggregate(aggregateFunction); + List fetchColumns = new ArrayList<>(); fetchColumns.add(aggregate); while (Randomly.getBooleanWithRatherLowProbability()) { fetchColumns.add(gen.generateAggregate()); } select.setFetchColumns(Arrays.asList(aggregate)); if (Randomly.getBooleanWithRatherLowProbability()) { - select.setOrderByExpressions(gen.generateOrderBys()); + select.setOrderByClauses(gen.generateOrderBys()); } originalQuery = DuckDBToStringVisitor.asString(select); firstResult = getAggregateResult(originalQuery); @@ -77,15 +75,15 @@ public void check() throws SQLException { } - private String createMetamorphicUnionQuery(DuckDBSelect select, - NewFunctionNode aggregate, List> from) { + private String createMetamorphicUnionQuery(DuckDBSelect select, DuckDBFunction aggregate, + List from) { String metamorphicQuery; - Node whereClause = gen.generateExpression(); - Node negatedClause = new NewUnaryPrefixOperatorNode<>(whereClause, + DuckDBExpression whereClause = gen.generateExpression(); + DuckDBExpression negatedClause = new sqlancer.duckdb.ast.DuckDBUnaryPrefixOperator(whereClause, DuckDBUnaryPrefixOperator.NOT); - Node notNullClause = new NewUnaryPostfixOperatorNode<>(whereClause, + DuckDBExpression notNullClause = new sqlancer.duckdb.ast.DuckDBUnaryPostfixOperator(whereClause, DuckDBUnaryPostfixOperator.IS_NULL); - List> mappedAggregate = mapped(aggregate); + List mappedAggregate = mapped(aggregate); DuckDBSelect leftSelect = getSelect(mappedAggregate, from, whereClause, select.getJoinList()); DuckDBSelect middleSelect = getSelect(mappedAggregate, from, negatedClause, select.getJoinList()); DuckDBSelect rightSelect = getSelect(mappedAggregate, from, notNullClause, select.getJoinList()); @@ -119,7 +117,7 @@ private String getAggregateResult(String queryString) throws SQLException { } } - private List> mapped(NewFunctionNode aggregate) { + private List mapped(DuckDBFunction aggregate) { DuckDBCastOperation count; switch (aggregate.getFunc()) { case COUNT: @@ -128,21 +126,19 @@ private List> mapped(NewFunctionNode sum = new NewFunctionNode<>(aggregate.getArgs(), + DuckDBFunction sum = new DuckDBFunction<>(aggregate.getArgs(), DuckDBAggregateFunction.SUM); - count = new DuckDBCastOperation(new NewFunctionNode<>(aggregate.getArgs(), DuckDBAggregateFunction.COUNT), + count = new DuckDBCastOperation(new DuckDBFunction<>(aggregate.getArgs(), DuckDBAggregateFunction.COUNT), new DuckDBCompositeDataType(DuckDBDataType.FLOAT, 8)); return aliasArgs(Arrays.asList(sum, count)); case STDDEV_POP: - NewFunctionNode sumSquared = new NewFunctionNode<>( - Arrays.asList(new NewBinaryOperatorNode<>(aggregate.getArgs().get(0), aggregate.getArgs().get(0), + DuckDBFunction sumSquared = new DuckDBFunction<>( + Arrays.asList(new DuckDBBinaryOperator(aggregate.getArgs().get(0), aggregate.getArgs().get(0), DuckDBBinaryArithmeticOperator.MULT)), DuckDBAggregateFunction.SUM); - count = new DuckDBCastOperation( - new NewFunctionNode(aggregate.getArgs(), - DuckDBAggregateFunction.COUNT), + count = new DuckDBCastOperation(new DuckDBFunction<>(aggregate.getArgs(), DuckDBAggregateFunction.COUNT), new DuckDBCompositeDataType(DuckDBDataType.FLOAT, 8)); - NewFunctionNode avg = new NewFunctionNode<>(aggregate.getArgs(), + DuckDBFunction avg = new DuckDBFunction<>(aggregate.getArgs(), DuckDBAggregateFunction.AVG); return aliasArgs(Arrays.asList(sumSquared, count, avg)); default: @@ -150,11 +146,11 @@ private List> mapped(NewFunctionNode> aliasArgs(List> originalAggregateArgs) { - List> args = new ArrayList<>(); + private List aliasArgs(List originalAggregateArgs) { + List args = new ArrayList<>(); int i = 0; - for (Node expr : originalAggregateArgs) { - args.add(new NewAliasNode(expr, "agg" + i++)); + for (DuckDBExpression expr : originalAggregateArgs) { + args.add(new DuckDBAlias(expr, "agg" + i++)); } return args; } @@ -172,8 +168,8 @@ private String getOuterAggregateFunction(NewFunctionNode> aggregates, List> from, - Node whereClause, List> joinList) { + private DuckDBSelect getSelect(List aggregates, List from, + DuckDBExpression whereClause, List joinList) { DuckDBSelect leftSelect = new DuckDBSelect(); leftSelect.setFetchColumns(aggregates); leftSelect.setFromList(from); diff --git a/src/sqlancer/duckdb/test/DuckDBQueryPartitioningBase.java b/src/sqlancer/duckdb/test/DuckDBQueryPartitioningBase.java index 08c862973..1e75be7b3 100644 --- a/src/sqlancer/duckdb/test/DuckDBQueryPartitioningBase.java +++ b/src/sqlancer/duckdb/test/DuckDBQueryPartitioningBase.java @@ -3,13 +3,9 @@ import java.sql.SQLException; import java.util.ArrayList; import java.util.List; -import java.util.Objects; import java.util.stream.Collectors; import sqlancer.Randomly; -import sqlancer.common.ast.newast.ColumnReferenceNode; -import sqlancer.common.ast.newast.Node; -import sqlancer.common.ast.newast.TableReferenceNode; import sqlancer.common.gen.ExpressionGenerator; import sqlancer.common.oracle.TernaryLogicPartitioningOracleBase; import sqlancer.common.oracle.TestOracle; @@ -19,13 +15,15 @@ import sqlancer.duckdb.DuckDBSchema.DuckDBColumn; import sqlancer.duckdb.DuckDBSchema.DuckDBTable; import sqlancer.duckdb.DuckDBSchema.DuckDBTables; +import sqlancer.duckdb.ast.DuckDBColumnReference; import sqlancer.duckdb.ast.DuckDBExpression; import sqlancer.duckdb.ast.DuckDBJoin; import sqlancer.duckdb.ast.DuckDBSelect; +import sqlancer.duckdb.ast.DuckDBTableReference; import sqlancer.duckdb.gen.DuckDBExpressionGenerator; -public class DuckDBQueryPartitioningBase - extends TernaryLogicPartitioningOracleBase, DuckDBGlobalState> implements TestOracle { +public class DuckDBQueryPartitioningBase extends TernaryLogicPartitioningOracleBase + implements TestOracle { DuckDBSchema s; DuckDBTables targetTables; @@ -37,15 +35,6 @@ public DuckDBQueryPartitioningBase(DuckDBGlobalState state) { DuckDBErrors.addExpressionErrors(errors); } - public static String canonicalizeResultValue(String value) { - // Rule: -0.0 should be canonicalized to 0.0 - if (Objects.equals(value, "-0.0")) { - return "0.0"; - } - - return value; - } - @Override public void check() throws SQLException { s = state.getSchema(); @@ -55,27 +44,27 @@ public void check() throws SQLException { select = new DuckDBSelect(); select.setFetchColumns(generateFetchColumns()); List tables = targetTables.getTables(); - List> tableList = tables.stream() - .map(t -> new TableReferenceNode(t)).collect(Collectors.toList()); - List> joins = DuckDBJoin.getJoins(tableList, state); + List tableList = tables.stream().map(t -> new DuckDBTableReference(t)) + .collect(Collectors.toList()); + List joins = DuckDBJoin.getJoins(tableList, state); select.setJoinList(joins.stream().collect(Collectors.toList())); select.setFromList(tableList.stream().collect(Collectors.toList())); select.setWhereClause(null); } - List> generateFetchColumns() { - List> columns = new ArrayList<>(); + List generateFetchColumns() { + List columns = new ArrayList<>(); if (Randomly.getBoolean()) { - columns.add(new ColumnReferenceNode<>(new DuckDBColumn("*", null, false, false))); + columns.add(new DuckDBColumnReference(new DuckDBColumn("*", null, false, false))); } else { - columns = Randomly.nonEmptySubset(targetTables.getColumns()).stream() - .map(c -> new ColumnReferenceNode(c)).collect(Collectors.toList()); + columns = Randomly.nonEmptySubset(targetTables.getColumns()).stream().map(c -> new DuckDBColumnReference(c)) + .collect(Collectors.toList()); } return columns; } @Override - protected ExpressionGenerator> getGen() { + protected ExpressionGenerator getGen() { return gen; } diff --git a/src/sqlancer/duckdb/test/DuckDBQueryPartitioningDistinctTester.java b/src/sqlancer/duckdb/test/DuckDBQueryPartitioningDistinctTester.java index d85d2a2d5..70cd731d6 100644 --- a/src/sqlancer/duckdb/test/DuckDBQueryPartitioningDistinctTester.java +++ b/src/sqlancer/duckdb/test/DuckDBQueryPartitioningDistinctTester.java @@ -38,7 +38,7 @@ public void check() throws SQLException { List secondResultSet = ComparatorHelper.getCombinedResultSetNoDuplicates(firstQueryString, secondQueryString, thirdQueryString, combinedString, true, state, errors); ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, - state, DuckDBQueryPartitioningBase::canonicalizeResultValue); + state, ComparatorHelper::canonicalizeResultValue); } } diff --git a/src/sqlancer/duckdb/test/DuckDBQueryPartitioningGroupByTester.java b/src/sqlancer/duckdb/test/DuckDBQueryPartitioningGroupByTester.java index fab480fed..c40a63a18 100644 --- a/src/sqlancer/duckdb/test/DuckDBQueryPartitioningGroupByTester.java +++ b/src/sqlancer/duckdb/test/DuckDBQueryPartitioningGroupByTester.java @@ -7,12 +7,10 @@ import sqlancer.ComparatorHelper; import sqlancer.Randomly; -import sqlancer.common.ast.newast.ColumnReferenceNode; -import sqlancer.common.ast.newast.Node; import sqlancer.duckdb.DuckDBErrors; import sqlancer.duckdb.DuckDBProvider.DuckDBGlobalState; -import sqlancer.duckdb.DuckDBSchema.DuckDBColumn; import sqlancer.duckdb.DuckDBToStringVisitor; +import sqlancer.duckdb.ast.DuckDBColumnReference; import sqlancer.duckdb.ast.DuckDBExpression; public class DuckDBQueryPartitioningGroupByTester extends DuckDBQueryPartitioningBase { @@ -41,13 +39,13 @@ public void check() throws SQLException { List secondResultSet = ComparatorHelper.getCombinedResultSetNoDuplicates(firstQueryString, secondQueryString, thirdQueryString, combinedString, true, state, errors); ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, - state, DuckDBQueryPartitioningBase::canonicalizeResultValue); + state, ComparatorHelper::canonicalizeResultValue); } @Override - List> generateFetchColumns() { - return Randomly.nonEmptySubset(targetTables.getColumns()).stream() - .map(c -> new ColumnReferenceNode(c)).collect(Collectors.toList()); + List generateFetchColumns() { + return Randomly.nonEmptySubset(targetTables.getColumns()).stream().map(c -> new DuckDBColumnReference(c)) + .collect(Collectors.toList()); } } diff --git a/src/sqlancer/duckdb/test/DuckDBQueryPartitioningHavingTester.java b/src/sqlancer/duckdb/test/DuckDBQueryPartitioningHavingTester.java index fc8843729..b0ff0a44c 100644 --- a/src/sqlancer/duckdb/test/DuckDBQueryPartitioningHavingTester.java +++ b/src/sqlancer/duckdb/test/DuckDBQueryPartitioningHavingTester.java @@ -7,14 +7,14 @@ import sqlancer.ComparatorHelper; import sqlancer.Randomly; -import sqlancer.common.ast.newast.Node; import sqlancer.common.oracle.TestOracle; import sqlancer.duckdb.DuckDBErrors; import sqlancer.duckdb.DuckDBProvider.DuckDBGlobalState; import sqlancer.duckdb.DuckDBToStringVisitor; import sqlancer.duckdb.ast.DuckDBExpression; -public class DuckDBQueryPartitioningHavingTester extends DuckDBQueryPartitioningBase implements TestOracle { +public class DuckDBQueryPartitioningHavingTester extends DuckDBQueryPartitioningBase + implements TestOracle { public DuckDBQueryPartitioningHavingTester(DuckDBGlobalState state) { super(state); @@ -29,7 +29,7 @@ public void check() throws SQLException { } boolean orderBy = Randomly.getBoolean(); if (orderBy) { - select.setOrderByExpressions(gen.generateOrderBys()); + select.setOrderByClauses(gen.generateOrderBys()); } select.setGroupByExpressions(gen.generateExpressions(Randomly.smallNumber() + 1)); select.setHavingClause(null); @@ -46,16 +46,16 @@ public void check() throws SQLException { List secondResultSet = ComparatorHelper.getCombinedResultSet(firstQueryString, secondQueryString, thirdQueryString, combinedString, !orderBy, state, errors); ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, - state, DuckDBQueryPartitioningBase::canonicalizeResultValue); + state, ComparatorHelper::canonicalizeResultValue); } @Override - protected Node generatePredicate() { + protected DuckDBExpression generatePredicate() { return gen.generateHavingClause(); } @Override - List> generateFetchColumns() { + List generateFetchColumns() { return Arrays.asList(gen.generateHavingClause()); } diff --git a/src/sqlancer/duckdb/test/DuckDBQueryPartitioningWhereTester.java b/src/sqlancer/duckdb/test/DuckDBQueryPartitioningWhereTester.java deleted file mode 100644 index 412ab6ffd..000000000 --- a/src/sqlancer/duckdb/test/DuckDBQueryPartitioningWhereTester.java +++ /dev/null @@ -1,45 +0,0 @@ -package sqlancer.duckdb.test; - -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.List; - -import sqlancer.ComparatorHelper; -import sqlancer.Randomly; -import sqlancer.duckdb.DuckDBErrors; -import sqlancer.duckdb.DuckDBProvider.DuckDBGlobalState; -import sqlancer.duckdb.DuckDBToStringVisitor; - -public class DuckDBQueryPartitioningWhereTester extends DuckDBQueryPartitioningBase { - - public DuckDBQueryPartitioningWhereTester(DuckDBGlobalState state) { - super(state); - DuckDBErrors.addGroupByErrors(errors); - } - - @Override - public void check() throws SQLException { - super.check(); - select.setWhereClause(null); - String originalQueryString = DuckDBToStringVisitor.asString(select); - - List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); - - boolean orderBy = Randomly.getBooleanWithRatherLowProbability(); - if (orderBy) { - select.setOrderByExpressions(gen.generateOrderBys()); - } - select.setWhereClause(predicate); - String firstQueryString = DuckDBToStringVisitor.asString(select); - select.setWhereClause(negatedPredicate); - String secondQueryString = DuckDBToStringVisitor.asString(select); - select.setWhereClause(isNullPredicate); - String thirdQueryString = DuckDBToStringVisitor.asString(select); - List combinedString = new ArrayList<>(); - List secondResultSet = ComparatorHelper.getCombinedResultSet(firstQueryString, secondQueryString, - thirdQueryString, combinedString, !orderBy, state, errors); - ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, - state, DuckDBQueryPartitioningBase::canonicalizeResultValue); - } - -} diff --git a/src/sqlancer/h2/H2CastNode.java b/src/sqlancer/h2/H2CastNode.java deleted file mode 100644 index 9d279c8dd..000000000 --- a/src/sqlancer/h2/H2CastNode.java +++ /dev/null @@ -1,24 +0,0 @@ -package sqlancer.h2; - -import sqlancer.common.ast.newast.Node; -import sqlancer.h2.H2Schema.H2CompositeDataType; - -public class H2CastNode implements Node { - - private final Node expression; - private final H2CompositeDataType type; - - public H2CastNode(Node expression, H2CompositeDataType type) { - this.expression = expression; - this.type = type; - } - - public Node getExpression() { - return expression; - } - - public H2CompositeDataType getType() { - return type; - } - -} diff --git a/src/sqlancer/h2/H2Errors.java b/src/sqlancer/h2/H2Errors.java index aa1cba9cd..5846c1a88 100644 --- a/src/sqlancer/h2/H2Errors.java +++ b/src/sqlancer/h2/H2Errors.java @@ -1,5 +1,8 @@ package sqlancer.h2; +import java.util.ArrayList; +import java.util.List; + import sqlancer.common.query.ExpectedErrors; public final class H2Errors { @@ -7,7 +10,8 @@ public final class H2Errors { private H2Errors() { } - public static void addInsertErrors(ExpectedErrors errors) { + public static List getInsertErrors() { + ArrayList errors = new ArrayList<>(); errors.add("NULL not allowed for column"); errors.add("Unique index or primary key violation"); errors.add("Data conversion error"); @@ -16,9 +20,15 @@ public static void addInsertErrors(ExpectedErrors errors) { errors.add("Referential integrity constraint violation"); errors.add("Check constraint invalid"); errors.add("Check constraint violation"); + return errors; } - public static void addExpressionErrors(ExpectedErrors errors) { + public static void addInsertErrors(ExpectedErrors errors) { + errors.addAll(getInsertErrors()); + } + + public static List getExpressionErrors() { + ArrayList errors = new ArrayList<>(); errors.add("java.lang.ArithmeticException: BigInteger would overflow supported range"); errors.add("Value too long for column"); errors.add("Numeric value out of range"); @@ -41,12 +51,23 @@ public static void addExpressionErrors(ExpectedErrors errors) { errors.add(/* precision */ "must be between"); // TRUNCATE_VALUE errors.add("Cannot parse \"TIMESTAMP\" constant"); // TRUNCATE errors.add("Invalid parameter count for \"TRUNC\", expected count: \"1\""); // TRUNCATE + return errors; } - public static void addDeleteErrors(ExpectedErrors errors) { + public static void addExpressionErrors(ExpectedErrors errors) { + errors.addAll(getExpressionErrors()); + } + + public static List getDeleteErrors() { + ArrayList errors = new ArrayList<>(); errors.add("No default value is set for column"); // referential actions errors.add("Referential integrity constraint violation"); errors.add("NULL not allowed for column"); + return errors; + } + + public static void addDeleteErrors(ExpectedErrors errors) { + errors.addAll(getDeleteErrors()); } } diff --git a/src/sqlancer/h2/H2Expression.java b/src/sqlancer/h2/H2Expression.java deleted file mode 100644 index 61b8c9658..000000000 --- a/src/sqlancer/h2/H2Expression.java +++ /dev/null @@ -1,5 +0,0 @@ -package sqlancer.h2; - -public interface H2Expression { - -} diff --git a/src/sqlancer/h2/H2ExpressionGenerator.java b/src/sqlancer/h2/H2ExpressionGenerator.java index d04251409..3d816292c 100644 --- a/src/sqlancer/h2/H2ExpressionGenerator.java +++ b/src/sqlancer/h2/H2ExpressionGenerator.java @@ -1,25 +1,37 @@ package sqlancer.h2; +import java.util.List; +import java.util.stream.Collectors; + import sqlancer.Randomly; import sqlancer.common.ast.BinaryOperatorNode.Operator; -import sqlancer.common.ast.newast.ColumnReferenceNode; -import sqlancer.common.ast.newast.NewBetweenOperatorNode; -import sqlancer.common.ast.newast.NewBinaryOperatorNode; -import sqlancer.common.ast.newast.NewCaseOperatorNode; -import sqlancer.common.ast.newast.NewFunctionNode; -import sqlancer.common.ast.newast.NewInOperatorNode; -import sqlancer.common.ast.newast.NewUnaryPostfixOperatorNode; -import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode; -import sqlancer.common.ast.newast.Node; +import sqlancer.common.gen.TLPWhereGenerator; import sqlancer.common.gen.UntypedExpressionGenerator; +import sqlancer.common.schema.AbstractTables; import sqlancer.h2.H2Provider.H2GlobalState; import sqlancer.h2.H2Schema.H2Column; import sqlancer.h2.H2Schema.H2CompositeDataType; import sqlancer.h2.H2Schema.H2DataType; - -public class H2ExpressionGenerator extends UntypedExpressionGenerator, H2Column> { +import sqlancer.h2.H2Schema.H2Table; +import sqlancer.h2.ast.H2BetweenOperation; +import sqlancer.h2.ast.H2BinaryOperation; +import sqlancer.h2.ast.H2CaseOperation; +import sqlancer.h2.ast.H2CastNode; +import sqlancer.h2.ast.H2ColumnReference; +import sqlancer.h2.ast.H2Constant; +import sqlancer.h2.ast.H2Expression; +import sqlancer.h2.ast.H2InOperation; +import sqlancer.h2.ast.H2Join; +import sqlancer.h2.ast.H2Select; +import sqlancer.h2.ast.H2TableReference; +import sqlancer.h2.ast.H2UnaryPostfixOperation; +import sqlancer.h2.ast.H2UnaryPrefixOperation; + +public class H2ExpressionGenerator extends UntypedExpressionGenerator + implements TLPWhereGenerator { private final H2GlobalState globalState; + private List tables; public H2ExpressionGenerator(H2GlobalState globalState) { this.globalState = globalState; @@ -31,7 +43,7 @@ private enum Expression { } @Override - protected Node generateExpression(int depth) { + protected H2Expression generateExpression(int depth) { if (depth >= globalState.getOptions().getMaxExpressionDepth() || Randomly.getBoolean()) { return generateLeafNode(); } @@ -39,37 +51,33 @@ protected Node generateExpression(int depth) { switch (expr) { case BINARY_COMPARISON: Operator op = H2BinaryComparisonOperator.getRandom(); - return new NewBinaryOperatorNode(generateExpression(depth + 1), generateExpression(depth + 1), - op); + return new H2BinaryOperation(generateExpression(depth + 1), generateExpression(depth + 1), op); case BINARY_LOGICAL: op = H2BinaryLogicalOperator.getRandom(); - return new NewBinaryOperatorNode(generateExpression(depth + 1), generateExpression(depth + 1), - op); + return new H2BinaryOperation(generateExpression(depth + 1), generateExpression(depth + 1), op); case UNARY_POSTFIX: op = H2UnaryPostfixOperator.getRandom(); - return new NewUnaryPostfixOperatorNode(generateExpression(depth + 1), op); + return new H2UnaryPostfixOperation(generateExpression(depth + 1), op); case UNARY_PREFIX: - return new NewUnaryPrefixOperatorNode(generateExpression(depth + 1), - H2UnaryPrefixOperator.getRandom()); + return new H2UnaryPrefixOperation(generateExpression(depth + 1), H2UnaryPrefixOperator.getRandom()); case IN: - return new NewInOperatorNode(generateExpression(depth + 1), + return new H2InOperation(generateExpression(depth + 1), generateExpressions(Randomly.smallNumber() + 1, depth + 1), Randomly.getBoolean()); case BETWEEN: - return new NewBetweenOperatorNode(generateExpression(depth + 1), - generateExpression(depth + 1), generateExpression(depth + 1), Randomly.getBoolean()); + return new H2BetweenOperation(generateExpression(depth + 1), generateExpression(depth + 1), + generateExpression(depth + 1), Randomly.getBoolean()); case CASE: int nr = Randomly.smallNumber() + 1; - return new NewCaseOperatorNode(generateExpression(depth + 1), - generateExpressions(nr, depth + 1), generateExpressions(nr, depth + 1), - generateExpression(depth + 1)); + return new H2CaseOperation(generateExpression(depth + 1), generateExpressions(nr, depth + 1), + generateExpressions(nr, depth + 1), generateExpression(depth + 1)); case BINARY_ARITHMETIC: - return new NewBinaryOperatorNode(generateExpression(depth + 1), generateExpression(depth + 1), + return new H2BinaryOperation(generateExpression(depth + 1), generateExpression(depth + 1), H2BinaryArithmeticOperator.getRandom()); case CAST: return new H2CastNode(generateExpression(depth + 1), H2CompositeDataType.getRandom()); case FUNCTION: H2Function func = H2Function.getRandom(); - return new NewFunctionNode(generateExpressions(func.getNrArgs()), func); + return new sqlancer.h2.ast.H2Function<>(generateExpressions(func.getNrArgs()), func); default: throw new AssertionError(); } @@ -203,12 +211,12 @@ public int getNrArgs() { } @Override - protected Node generateColumn() { - return new ColumnReferenceNode(Randomly.fromList(columns)); + protected H2Expression generateColumn() { + return new H2ColumnReference(Randomly.fromList(columns)); } @Override - public Node generateConstant() { + public H2Expression generateConstant() { if (Randomly.getBooleanWithSmallProbability()) { return H2Constant.createNullConstant(); } @@ -330,13 +338,54 @@ public String getTextRepresentation() { } @Override - public Node negatePredicate(Node predicate) { - return new NewUnaryPrefixOperatorNode<>(predicate, H2UnaryPrefixOperator.NOT); + public H2Expression negatePredicate(H2Expression predicate) { + return new H2UnaryPrefixOperation(predicate, H2UnaryPrefixOperator.NOT); + } + + @Override + public H2Expression isNull(H2Expression expr) { + return new H2UnaryPostfixOperation(expr, H2UnaryPostfixOperator.IS_NULL); } @Override - public Node isNull(Node expr) { - return new NewUnaryPostfixOperatorNode<>(expr, H2UnaryPostfixOperator.IS_NULL); + public TLPWhereGenerator setTablesAndColumns( + AbstractTables tables) { + this.columns = tables.getColumns(); + this.tables = tables.getTables(); + + return this; } + @Override + public H2Expression generateBooleanExpression() { + return generateExpression(); + } + + @Override + public H2Select generateSelect() { + return new H2Select(); + } + + @Override + public List getRandomJoinClauses() { + List tableList = tables.stream().map(t -> new H2TableReference(t)) + .collect(Collectors.toList()); + List joins = H2Join.getJoins(tableList, globalState); + tables = tableList.stream().map(t -> t.getTable()).collect(Collectors.toList()); + return joins; + } + + @Override + public List getTableRefs() { + return tables.stream().map(t -> new H2TableReference(t)).collect(Collectors.toList()); + } + + @Override + public List generateFetchColumns(boolean shouldCreateDummy) { + if (shouldCreateDummy && Randomly.getBoolean()) { + return List.of(new H2ColumnReference(new H2Column("*", null))); + } + return Randomly.nonEmptySubset(this.columns).stream().map(c -> new H2ColumnReference(c)) + .collect(Collectors.toList()); + } } diff --git a/src/sqlancer/h2/H2InsertGenerator.java b/src/sqlancer/h2/H2InsertGenerator.java index cfb17846a..c4e559b0e 100644 --- a/src/sqlancer/h2/H2InsertGenerator.java +++ b/src/sqlancer/h2/H2InsertGenerator.java @@ -56,7 +56,7 @@ private SQLQueryAdapter generate() { } @Override - protected void insertValue(H2Column tiDBColumn) { + protected void insertValue(H2Column columnH2) { sb.append(H2ToStringVisitor.asString(gen.generateConstant())); } } diff --git a/src/sqlancer/h2/H2Options.java b/src/sqlancer/h2/H2Options.java index 45b47f203..985853e94 100644 --- a/src/sqlancer/h2/H2Options.java +++ b/src/sqlancer/h2/H2Options.java @@ -1,33 +1,15 @@ package sqlancer.h2; -import java.sql.SQLException; import java.util.Arrays; import java.util.List; import com.beust.jcommander.Parameters; import sqlancer.DBMSSpecificOptions; -import sqlancer.OracleFactory; -import sqlancer.common.oracle.TestOracle; -import sqlancer.h2.H2Options.H2OracleFactory; -import sqlancer.h2.H2Provider.H2GlobalState; @Parameters(commandDescription = "H2") public class H2Options implements DBMSSpecificOptions { - public enum H2OracleFactory implements OracleFactory { - - TLP_WHERE { - - @Override - public TestOracle create(H2GlobalState globalState) throws SQLException { - return new H2QueryPartitioningWhereTester(globalState); - } - - }; - - } - @Override public List getTestOracleFactory() { return Arrays.asList(H2OracleFactory.TLP_WHERE); diff --git a/src/sqlancer/h2/H2OracleFactory.java b/src/sqlancer/h2/H2OracleFactory.java new file mode 100644 index 000000000..fbbaade05 --- /dev/null +++ b/src/sqlancer/h2/H2OracleFactory.java @@ -0,0 +1,23 @@ +package sqlancer.h2; + +import java.sql.SQLException; + +import sqlancer.OracleFactory; +import sqlancer.common.oracle.TLPWhereOracle; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; + +public enum H2OracleFactory implements OracleFactory { + + TLP_WHERE { + @Override + public TestOracle create(H2Provider.H2GlobalState globalState) throws SQLException { + H2ExpressionGenerator gen = new H2ExpressionGenerator(globalState); + ExpectedErrors expectedErrors = ExpectedErrors.newErrors().with(H2Errors.getExpressionErrors()).build(); + + return new TLPWhereOracle<>(globalState, gen, expectedErrors); + } + + }; + +} diff --git a/src/sqlancer/h2/H2QueryPartitioningBase.java b/src/sqlancer/h2/H2QueryPartitioningBase.java deleted file mode 100644 index b15f1c312..000000000 --- a/src/sqlancer/h2/H2QueryPartitioningBase.java +++ /dev/null @@ -1,60 +0,0 @@ -package sqlancer.h2; - -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.List; -import java.util.stream.Collectors; - -import sqlancer.common.ast.newast.ColumnReferenceNode; -import sqlancer.common.ast.newast.Node; -import sqlancer.common.ast.newast.TableReferenceNode; -import sqlancer.common.gen.ExpressionGenerator; -import sqlancer.common.oracle.TernaryLogicPartitioningOracleBase; -import sqlancer.common.oracle.TestOracle; -import sqlancer.h2.H2Provider.H2GlobalState; -import sqlancer.h2.H2Schema.H2Column; -import sqlancer.h2.H2Schema.H2Table; -import sqlancer.h2.H2Schema.H2Tables; - -public class H2QueryPartitioningBase extends TernaryLogicPartitioningOracleBase, H2GlobalState> - implements TestOracle { - - H2Schema s; - H2Tables targetTables; - H2ExpressionGenerator gen; - H2Select select; - - public H2QueryPartitioningBase(H2GlobalState state) { - super(state); - H2Errors.addExpressionErrors(errors); - } - - @Override - public void check() throws SQLException { - s = state.getSchema(); - targetTables = s.getRandomTableNonEmptyTables(); - gen = new H2ExpressionGenerator(state).setColumns(targetTables.getColumns()); - initializeTernaryPredicateVariants(); - select = new H2Select(); - select.setFetchColumns(generateFetchColumns()); - List tables = targetTables.getTables(); - List> tableList = tables.stream() - .map(t -> new TableReferenceNode(t)).collect(Collectors.toList()); - List> joins = H2Join.getJoins(tableList, state); - select.setJoinList(joins.stream().collect(Collectors.toList())); - select.setFromList(tableList.stream().collect(Collectors.toList())); - select.setWhereClause(null); - } - - List> generateFetchColumns() { - List> columns = new ArrayList<>(); - columns.add(new ColumnReferenceNode<>(new H2Column("*", null))); - return columns; - } - - @Override - protected ExpressionGenerator> getGen() { - return gen; - } - -} diff --git a/src/sqlancer/h2/H2QueryPartitioningWhereTester.java b/src/sqlancer/h2/H2QueryPartitioningWhereTester.java deleted file mode 100644 index 8de4ecd37..000000000 --- a/src/sqlancer/h2/H2QueryPartitioningWhereTester.java +++ /dev/null @@ -1,42 +0,0 @@ -package sqlancer.h2; - -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.List; - -import sqlancer.ComparatorHelper; -import sqlancer.Randomly; -import sqlancer.h2.H2Provider.H2GlobalState; - -public class H2QueryPartitioningWhereTester extends H2QueryPartitioningBase { - - public H2QueryPartitioningWhereTester(H2GlobalState state) { - super(state); - } - - @Override - public void check() throws SQLException { - super.check(); - select.setWhereClause(null); - String originalQueryString = H2ToStringVisitor.asString(select); - - List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); - - boolean orderBy = Randomly.getBooleanWithRatherLowProbability(); - if (orderBy) { - select.setOrderByExpressions(gen.generateOrderBys()); - } - select.setWhereClause(predicate); - String firstQueryString = H2ToStringVisitor.asString(select); - select.setWhereClause(negatedPredicate); - String secondQueryString = H2ToStringVisitor.asString(select); - select.setWhereClause(isNullPredicate); - String thirdQueryString = H2ToStringVisitor.asString(select); - List combinedString = new ArrayList<>(); - List secondResultSet = ComparatorHelper.getCombinedResultSet(firstQueryString, secondQueryString, - thirdQueryString, combinedString, !orderBy, state, errors); - ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, - state); - } - -} diff --git a/src/sqlancer/h2/H2RandomQuerySynthesizer.java b/src/sqlancer/h2/H2RandomQuerySynthesizer.java index 37942f0e5..7390b87e7 100644 --- a/src/sqlancer/h2/H2RandomQuerySynthesizer.java +++ b/src/sqlancer/h2/H2RandomQuerySynthesizer.java @@ -5,11 +5,14 @@ import java.util.stream.Collectors; import sqlancer.Randomly; -import sqlancer.common.ast.newast.Node; -import sqlancer.common.ast.newast.TableReferenceNode; import sqlancer.h2.H2Provider.H2GlobalState; import sqlancer.h2.H2Schema.H2Table; import sqlancer.h2.H2Schema.H2Tables; +import sqlancer.h2.ast.H2Constant; +import sqlancer.h2.ast.H2Expression; +import sqlancer.h2.ast.H2Join; +import sqlancer.h2.ast.H2Select; +import sqlancer.h2.ast.H2TableReference; public final class H2RandomQuerySynthesizer { @@ -20,23 +23,23 @@ public static H2Select generateSelect(H2GlobalState globalState, int nrColumns) H2Tables targetTables = globalState.getSchema().getRandomTableNonEmptyTables(); H2ExpressionGenerator gen = new H2ExpressionGenerator(globalState).setColumns(targetTables.getColumns()); H2Select select = new H2Select(); - List> columns = new ArrayList<>(); + List columns = new ArrayList<>(); for (int i = 0; i < nrColumns; i++) { - Node expression = gen.generateExpression(); + H2Expression expression = gen.generateExpression(); columns.add(expression); } select.setFetchColumns(columns); List tables = targetTables.getTables(); - List> tableList = tables.stream() - .map(t -> new TableReferenceNode(t)).collect(Collectors.toList()); - List> joins = H2Join.getJoins(tableList, globalState); + List tableList = tables.stream().map(t -> new H2TableReference(t)) + .collect(Collectors.toList()); + List joins = H2Join.getJoins(tableList, globalState); select.setJoinList(joins.stream().collect(Collectors.toList())); select.setFromList(tableList.stream().collect(Collectors.toList())); if (Randomly.getBoolean()) { select.setWhereClause(gen.generateExpression()); } if (Randomly.getBoolean()) { - select.setOrderByExpressions(gen.generateOrderBys()); + select.setOrderByClauses(gen.generateOrderBys()); } if (Randomly.getBoolean()) { select.setGroupByExpressions(gen.generateExpressions(Randomly.smallNumber() + 1)); diff --git a/src/sqlancer/h2/H2Select.java b/src/sqlancer/h2/H2Select.java deleted file mode 100644 index 07191659d..000000000 --- a/src/sqlancer/h2/H2Select.java +++ /dev/null @@ -1,8 +0,0 @@ -package sqlancer.h2; - -import sqlancer.common.ast.SelectBase; -import sqlancer.common.ast.newast.Node; - -public class H2Select extends SelectBase> implements Node { - -} diff --git a/src/sqlancer/h2/H2ToStringVisitor.java b/src/sqlancer/h2/H2ToStringVisitor.java index cb86ede30..ab18ab235 100644 --- a/src/sqlancer/h2/H2ToStringVisitor.java +++ b/src/sqlancer/h2/H2ToStringVisitor.java @@ -1,12 +1,16 @@ package sqlancer.h2; import sqlancer.common.ast.newast.NewToStringVisitor; -import sqlancer.common.ast.newast.Node; +import sqlancer.h2.ast.H2CastNode; +import sqlancer.h2.ast.H2Constant; +import sqlancer.h2.ast.H2Expression; +import sqlancer.h2.ast.H2Join; +import sqlancer.h2.ast.H2Select; public class H2ToStringVisitor extends NewToStringVisitor { @Override - public void visitSpecific(Node expr) { + public void visitSpecific(H2Expression expr) { if (expr instanceof H2Constant) { visit((H2Constant) expr); } else if (expr instanceof H2Select) { @@ -33,11 +37,11 @@ private void visit(H2CastNode cast) { } private void visit(H2Join join) { - visit(join.getLeftTable()); + visit((H2Expression) join.getLeftTable()); sb.append(" "); sb.append(join.getJoinType()); sb.append(" JOIN "); - visit(join.getRightTable()); + visit((H2Expression) join.getRightTable()); if (join.getOnCondition() != null) { sb.append(" ON "); visit(join.getOnCondition()); @@ -67,9 +71,9 @@ public void visit(H2Select select) { sb.append(" HAVING "); visit(select.getHavingClause()); } - if (!select.getOrderByExpressions().isEmpty()) { + if (!select.getOrderByClauses().isEmpty()) { sb.append(" ORDER BY "); - visit(select.getOrderByExpressions()); + visit(select.getOrderByClauses()); } if (select.getLimitClause() != null) { sb.append(" LIMIT "); @@ -81,7 +85,7 @@ public void visit(H2Select select) { } } - public static String asString(Node expr) { + public static String asString(H2Expression expr) { H2ToStringVisitor visitor = new H2ToStringVisitor(); visitor.visit(expr); return visitor.get(); diff --git a/src/sqlancer/h2/H2UpdateGenerator.java b/src/sqlancer/h2/H2UpdateGenerator.java index 0dfc15938..158621409 100644 --- a/src/sqlancer/h2/H2UpdateGenerator.java +++ b/src/sqlancer/h2/H2UpdateGenerator.java @@ -3,33 +3,33 @@ import java.util.List; import sqlancer.Randomly; -import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.gen.AbstractUpdateGenerator; import sqlancer.common.query.SQLQueryAdapter; import sqlancer.h2.H2Provider.H2GlobalState; import sqlancer.h2.H2Schema.H2Column; import sqlancer.h2.H2Schema.H2Table; -public final class H2UpdateGenerator { +public final class H2UpdateGenerator extends AbstractUpdateGenerator { - private H2UpdateGenerator() { + private final H2GlobalState globalState; + private H2ExpressionGenerator gen; + + private H2UpdateGenerator(H2GlobalState globalState) { + this.globalState = globalState; } public static SQLQueryAdapter getQuery(H2GlobalState globalState) { - StringBuilder sb = new StringBuilder("UPDATE "); - ExpectedErrors errors = new ExpectedErrors(); + return new H2UpdateGenerator(globalState).generate(); + } + + private SQLQueryAdapter generate() { H2Table table = globalState.getSchema().getRandomTable(t -> !t.isView()); + List columns = table.getRandomNonEmptyColumnSubset(); + gen = new H2ExpressionGenerator(globalState).setColumns(table.getColumns()); + sb.append("UPDATE "); sb.append(table.getName()); - H2ExpressionGenerator gen = new H2ExpressionGenerator(globalState).setColumns(table.getColumns()); sb.append(" SET "); - List columns = table.getRandomNonEmptyColumnSubset(); - for (int i = 0; i < columns.size(); i++) { - if (i != 0) { - sb.append(", "); - } - sb.append(columns.get(i).getName()); - sb.append("="); - sb.append(H2ToStringVisitor.asString(gen.generateConstant())); - } + updateColumns(columns); H2Errors.addInsertErrors(errors); H2Errors.addDeleteErrors(errors); if (Randomly.getBoolean()) { @@ -40,4 +40,9 @@ public static SQLQueryAdapter getQuery(H2GlobalState globalState) { return new SQLQueryAdapter(sb.toString(), errors); } + @Override + protected void updateValue(H2Column column) { + sb.append(H2ToStringVisitor.asString(gen.generateConstant())); + } + } diff --git a/src/sqlancer/h2/ast/H2BetweenOperation.java b/src/sqlancer/h2/ast/H2BetweenOperation.java new file mode 100644 index 000000000..77921d91e --- /dev/null +++ b/src/sqlancer/h2/ast/H2BetweenOperation.java @@ -0,0 +1,9 @@ +package sqlancer.h2.ast; + +import sqlancer.common.ast.newast.NewBetweenOperatorNode; + +public class H2BetweenOperation extends NewBetweenOperatorNode implements H2Expression { + public H2BetweenOperation(H2Expression left, H2Expression middle, H2Expression right, boolean isTrue) { + super(left, middle, right, isTrue); + } +} diff --git a/src/sqlancer/h2/ast/H2BinaryOperation.java b/src/sqlancer/h2/ast/H2BinaryOperation.java new file mode 100644 index 000000000..d523256fd --- /dev/null +++ b/src/sqlancer/h2/ast/H2BinaryOperation.java @@ -0,0 +1,10 @@ +package sqlancer.h2.ast; + +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.common.ast.newast.NewBinaryOperatorNode; + +public class H2BinaryOperation extends NewBinaryOperatorNode implements H2Expression { + public H2BinaryOperation(H2Expression left, H2Expression right, Operator op) { + super(left, right, op); + } +} diff --git a/src/sqlancer/h2/ast/H2CaseOperation.java b/src/sqlancer/h2/ast/H2CaseOperation.java new file mode 100644 index 000000000..75e51fc4d --- /dev/null +++ b/src/sqlancer/h2/ast/H2CaseOperation.java @@ -0,0 +1,12 @@ +package sqlancer.h2.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewCaseOperatorNode; + +public class H2CaseOperation extends NewCaseOperatorNode implements H2Expression { + public H2CaseOperation(H2Expression switchCondition, List conditions, List expressions, + H2Expression elseExpr) { + super(switchCondition, conditions, expressions, elseExpr); + } +} diff --git a/src/sqlancer/h2/ast/H2CastNode.java b/src/sqlancer/h2/ast/H2CastNode.java new file mode 100644 index 000000000..5e7602311 --- /dev/null +++ b/src/sqlancer/h2/ast/H2CastNode.java @@ -0,0 +1,23 @@ +package sqlancer.h2.ast; + +import sqlancer.h2.H2Schema.H2CompositeDataType; + +public class H2CastNode implements H2Expression { + + private final H2Expression expression; + private final H2CompositeDataType type; + + public H2CastNode(H2Expression expression, H2CompositeDataType type) { + this.expression = expression; + this.type = type; + } + + public H2Expression getExpression() { + return expression; + } + + public H2CompositeDataType getType() { + return type; + } + +} diff --git a/src/sqlancer/h2/ast/H2ColumnReference.java b/src/sqlancer/h2/ast/H2ColumnReference.java new file mode 100644 index 000000000..5864f0b53 --- /dev/null +++ b/src/sqlancer/h2/ast/H2ColumnReference.java @@ -0,0 +1,11 @@ +package sqlancer.h2.ast; + +import sqlancer.common.ast.newast.ColumnReferenceNode; +import sqlancer.h2.H2Schema; + +public class H2ColumnReference extends ColumnReferenceNode implements H2Expression { + public H2ColumnReference(H2Schema.H2Column column) { + super(column); + } + +} diff --git a/src/sqlancer/h2/H2Constant.java b/src/sqlancer/h2/ast/H2Constant.java similarity index 83% rename from src/sqlancer/h2/H2Constant.java rename to src/sqlancer/h2/ast/H2Constant.java index 4390fed26..923bcb4a2 100644 --- a/src/sqlancer/h2/H2Constant.java +++ b/src/sqlancer/h2/ast/H2Constant.java @@ -1,8 +1,6 @@ -package sqlancer.h2; +package sqlancer.h2.ast; -import sqlancer.common.ast.newast.Node; - -public class H2Constant implements Node { +public class H2Constant implements H2Expression { private H2Constant() { } @@ -108,27 +106,27 @@ public String toString() { } - public static Node createIntConstant(long val) { + public static H2Expression createIntConstant(long val) { return new H2IntConstant(val); } - public static Node createNullConstant() { + public static H2Expression createNullConstant() { return new H2NullConstant(); } - public static Node createBoolConstant(boolean val) { + public static H2Expression createBoolConstant(boolean val) { return new H2BoolConstant(val); } - public static Node createStringConstant(String val) { + public static H2Expression createStringConstant(String val) { return new H2StringConstant(val); } - public static Node createDoubleConstant(double val) { + public static H2Expression createDoubleConstant(double val) { return new H2DoubleConstant(val); } - public static Node createBinaryConstant(long val) { + public static H2Expression createBinaryConstant(long val) { return new H2BinaryConstant(val); } diff --git a/src/sqlancer/h2/ast/H2Expression.java b/src/sqlancer/h2/ast/H2Expression.java new file mode 100644 index 000000000..afd114c91 --- /dev/null +++ b/src/sqlancer/h2/ast/H2Expression.java @@ -0,0 +1,8 @@ +package sqlancer.h2.ast; + +import sqlancer.common.ast.newast.Expression; +import sqlancer.h2.H2Schema.H2Column; + +public interface H2Expression extends Expression { + +} diff --git a/src/sqlancer/h2/ast/H2Function.java b/src/sqlancer/h2/ast/H2Function.java new file mode 100644 index 000000000..9442d6db7 --- /dev/null +++ b/src/sqlancer/h2/ast/H2Function.java @@ -0,0 +1,11 @@ +package sqlancer.h2.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewFunctionNode; + +public class H2Function extends NewFunctionNode implements H2Expression { + public H2Function(List args, F func) { + super(args, func); + } +} diff --git a/src/sqlancer/h2/ast/H2InOperation.java b/src/sqlancer/h2/ast/H2InOperation.java new file mode 100644 index 000000000..1389fb18b --- /dev/null +++ b/src/sqlancer/h2/ast/H2InOperation.java @@ -0,0 +1,11 @@ +package sqlancer.h2.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewInOperatorNode; + +public class H2InOperation extends NewInOperatorNode implements H2Expression { + public H2InOperation(H2Expression left, List right, boolean isNegated) { + super(left, right, isNegated); + } +} diff --git a/src/sqlancer/h2/H2Join.java b/src/sqlancer/h2/ast/H2Join.java similarity index 55% rename from src/sqlancer/h2/H2Join.java rename to src/sqlancer/h2/ast/H2Join.java index 1b566c5f0..6aafd1de9 100644 --- a/src/sqlancer/h2/H2Join.java +++ b/src/sqlancer/h2/ast/H2Join.java @@ -1,21 +1,21 @@ -package sqlancer.h2; +package sqlancer.h2.ast; import java.util.ArrayList; import java.util.List; import sqlancer.Randomly; -import sqlancer.common.ast.newast.Node; -import sqlancer.common.ast.newast.TableReferenceNode; +import sqlancer.common.ast.newast.Join; +import sqlancer.h2.H2ExpressionGenerator; import sqlancer.h2.H2Provider.H2GlobalState; import sqlancer.h2.H2Schema.H2Column; import sqlancer.h2.H2Schema.H2Table; -public class H2Join implements Node { +public class H2Join implements H2Expression, Join { - private final TableReferenceNode leftTable; - private final TableReferenceNode rightTable; + private final H2TableReference leftTable; + private final H2TableReference rightTable; private final JoinType joinType; - private final Node onCondition; + private H2Expression onCondition; public enum JoinType { INNER, CROSS, NATURAL, LEFT, RIGHT; @@ -25,20 +25,19 @@ public static JoinType getRandom() { } } - public H2Join(TableReferenceNode leftTable, - TableReferenceNode rightTable, JoinType joinType, - Node whereCondition) { + public H2Join(H2TableReference leftTable, H2TableReference rightTable, JoinType joinType, + H2Expression whereCondition) { this.leftTable = leftTable; this.rightTable = rightTable; this.joinType = joinType; this.onCondition = whereCondition; } - public TableReferenceNode getLeftTable() { + public H2TableReference getLeftTable() { return leftTable; } - public TableReferenceNode getRightTable() { + public H2TableReference getRightTable() { return rightTable; } @@ -46,16 +45,15 @@ public JoinType getJoinType() { return joinType; } - public Node getOnCondition() { + public H2Expression getOnCondition() { return onCondition; } - public static List> getJoins(List> tableList, - H2GlobalState globalState) { - List> joinExpressions = new ArrayList<>(); + public static List getJoins(List tableList, H2GlobalState globalState) { + List joinExpressions = new ArrayList<>(); while (tableList.size() >= 2 && Randomly.getBooleanWithRatherLowProbability()) { - TableReferenceNode leftTable = tableList.remove(0); - TableReferenceNode rightTable = tableList.remove(0); + H2TableReference leftTable = tableList.remove(0); + H2TableReference rightTable = tableList.remove(0); List columns = new ArrayList<>(leftTable.getTable().getColumns()); columns.addAll(rightTable.getTable().getColumns()); H2ExpressionGenerator joinGen = new H2ExpressionGenerator(globalState).setColumns(columns); @@ -83,24 +81,25 @@ public static List> getJoins(List left, - TableReferenceNode right, Node predicate) { + public static H2Join createRightOuterJoin(H2TableReference left, H2TableReference right, H2Expression predicate) { return new H2Join(left, right, JoinType.RIGHT, predicate); } - public static H2Join createLeftOuterJoin(TableReferenceNode left, - TableReferenceNode right, Node predicate) { + public static H2Join createLeftOuterJoin(H2TableReference left, H2TableReference right, H2Expression predicate) { return new H2Join(left, right, JoinType.LEFT, predicate); } - public static H2Join createInnerJoin(TableReferenceNode left, - TableReferenceNode right, Node predicate) { + public static H2Join createInnerJoin(H2TableReference left, H2TableReference right, H2Expression predicate) { return new H2Join(left, right, JoinType.INNER, predicate); } - public static Node createNaturalJoin(TableReferenceNode left, - TableReferenceNode right) { + public static H2Join createNaturalJoin(H2TableReference left, H2TableReference right) { return new H2Join(left, right, JoinType.NATURAL, null); } + @Override + public void setOnClause(H2Expression onClause) { + onCondition = onClause; + } + } diff --git a/src/sqlancer/h2/ast/H2Select.java b/src/sqlancer/h2/ast/H2Select.java new file mode 100644 index 000000000..e2f6f8519 --- /dev/null +++ b/src/sqlancer/h2/ast/H2Select.java @@ -0,0 +1,31 @@ +package sqlancer.h2.ast; + +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.common.ast.SelectBase; +import sqlancer.common.ast.newast.Select; +import sqlancer.h2.H2Schema.H2Column; +import sqlancer.h2.H2Schema.H2Table; +import sqlancer.h2.H2ToStringVisitor; + +public class H2Select extends SelectBase + implements H2Expression, Select { + + @Override + public void setJoinClauses(List joinStatements) { + List expressions = joinStatements.stream().map(e -> (H2Expression) e) + .collect(Collectors.toList()); + setJoinList(expressions); + } + + @Override + public List getJoinClauses() { + return getJoinList().stream().map(e -> (H2Join) e).collect(Collectors.toList()); + } + + @Override + public String asString() { + return H2ToStringVisitor.asString(this); + } +} diff --git a/src/sqlancer/h2/ast/H2TableReference.java b/src/sqlancer/h2/ast/H2TableReference.java new file mode 100644 index 000000000..459337c41 --- /dev/null +++ b/src/sqlancer/h2/ast/H2TableReference.java @@ -0,0 +1,10 @@ +package sqlancer.h2.ast; + +import sqlancer.common.ast.newast.TableReferenceNode; +import sqlancer.h2.H2Schema; + +public class H2TableReference extends TableReferenceNode implements H2Expression { + public H2TableReference(H2Schema.H2Table table) { + super(table); + } +} diff --git a/src/sqlancer/h2/ast/H2UnaryPostfixOperation.java b/src/sqlancer/h2/ast/H2UnaryPostfixOperation.java new file mode 100644 index 000000000..5c43effbe --- /dev/null +++ b/src/sqlancer/h2/ast/H2UnaryPostfixOperation.java @@ -0,0 +1,10 @@ +package sqlancer.h2.ast; + +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.newast.NewUnaryPostfixOperatorNode; + +public class H2UnaryPostfixOperation extends NewUnaryPostfixOperatorNode implements H2Expression { + public H2UnaryPostfixOperation(H2Expression expr, BinaryOperatorNode.Operator op) { + super(expr, op); + } +} diff --git a/src/sqlancer/h2/ast/H2UnaryPrefixOperation.java b/src/sqlancer/h2/ast/H2UnaryPrefixOperation.java new file mode 100644 index 000000000..24085dcb6 --- /dev/null +++ b/src/sqlancer/h2/ast/H2UnaryPrefixOperation.java @@ -0,0 +1,10 @@ +package sqlancer.h2.ast; + +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode; + +public class H2UnaryPrefixOperation extends NewUnaryPrefixOperatorNode implements H2Expression { + public H2UnaryPrefixOperation(H2Expression expr, BinaryOperatorNode.Operator operator) { + super(expr, operator); + } +} diff --git a/src/sqlancer/hive/HiveBugs.java b/src/sqlancer/hive/HiveBugs.java new file mode 100644 index 000000000..43d42ce0b --- /dev/null +++ b/src/sqlancer/hive/HiveBugs.java @@ -0,0 +1,41 @@ +package sqlancer.hive; + +// do not make the fields final to avoid warnings +public final class HiveBugs { + + // Incorrect IS NULL evaluation for negation of string concatenation involving column references. + // -(c || 'x') evaluates to NULL at runtime, but IS NULL incorrectly returns false. + // The optimizer's nullability inference for GenericUDFOPNegative does not account for + // runtime conversion failures producing NULL from non-null input. + // Reproduce: CREATE TABLE t(c DOUBLE); INSERT INTO t VALUES(1.0); + // SELECT (-(c || 'x')) IS NULL FROM t; -- returns false, expected true + // Affects: 4.0.1, 4.2.0 + public static boolean bugNegationNullability = true; + + // Non-boolean expressions (CAST to non-boolean, FLOOR, ROUND, arithmetic) silently + // return 0 rows for all three TLP partitions when used as WHERE predicates. + // Hive requires BOOLEAN in WHERE but does not error; instead it returns empty results. + // Reproduce: CREATE TABLE t(c INT); INSERT INTO t VALUES(1); + // SELECT * FROM t WHERE FLOOR(1); -- returns 0 rows, expected 1 + // Affects: 4.0.1, 4.2.0 + public static boolean bugNonBooleanWhereClause = true; + + // IN operator with boolean sub-expressions involving IS NULL evaluates incorrectly, + // returning 0 rows for all three TLP partitions. + // Reproduce: CREATE TABLE t(c BOOLEAN); INSERT INTO t VALUES(true),(false); + // SELECT * FROM t WHERE (c != c) IN ((false) IS NULL); -- returns 0, expected 2 + // Affects: 4.0.1, 4.2.0 + public static boolean bugInBooleanEvaluation = true; + + // BETWEEN with mixed boolean/numeric types has incorrect TLP evaluation. + // The IS NULL partition misses rows due to wrong nullability inference. + // Reproduce: CREATE TABLE t(c DOUBLE); INSERT INTO t VALUES(0.5),(1.5); + // SELECT * FROM t WHERE (c NOT IN (true)) NOT BETWEEN 0.01 AND c; + // -- TLP partitions lose rows + // Affects: 4.0.1, 4.2.0 + public static boolean bugBetweenMixedTypes = true; + + private HiveBugs() { + } + +} diff --git a/src/sqlancer/hive/HiveErrors.java b/src/sqlancer/hive/HiveErrors.java new file mode 100644 index 000000000..81b0be668 --- /dev/null +++ b/src/sqlancer/hive/HiveErrors.java @@ -0,0 +1,40 @@ +package sqlancer.hive; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.common.query.ExpectedErrors; + +public final class HiveErrors { + + private HiveErrors() { + } + + public static List getExpressionErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("cannot recognize input near"); + errors.add("Argument type mismatch"); + errors.add("Error while compiling statement"); + + return errors; + } + + public static void addExpressionErrors(ExpectedErrors errors) { + errors.addAll(getExpressionErrors()); + } + + public static List getInsertErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("Either CHECK or NOT NULL constraint violated!"); + errors.add("Error running query"); + errors.add("is different from preceding arguments"); + + return errors; + } + + public static void addInsertErrors(ExpectedErrors errors) { + errors.addAll(getInsertErrors()); + } +} diff --git a/src/sqlancer/hive/HiveGlobalState.java b/src/sqlancer/hive/HiveGlobalState.java new file mode 100644 index 000000000..a362c1f88 --- /dev/null +++ b/src/sqlancer/hive/HiveGlobalState.java @@ -0,0 +1,11 @@ +package sqlancer.hive; + +import sqlancer.SQLGlobalState; + +public class HiveGlobalState extends SQLGlobalState { + + @Override + protected HiveSchema readSchema() throws Exception { + return HiveSchema.fromConnection(getConnection(), getDatabaseName()); + } +} diff --git a/src/sqlancer/hive/HiveOptions.java b/src/sqlancer/hive/HiveOptions.java new file mode 100644 index 000000000..ea55ac676 --- /dev/null +++ b/src/sqlancer/hive/HiveOptions.java @@ -0,0 +1,43 @@ +package sqlancer.hive; + +import java.sql.SQLException; +import java.util.Arrays; +import java.util.List; + +import com.beust.jcommander.Parameter; +import com.beust.jcommander.Parameters; + +import sqlancer.DBMSSpecificOptions; +import sqlancer.OracleFactory; +import sqlancer.common.oracle.TLPWhereOracle; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.hive.gen.HiveExpressionGenerator; + +@Parameters(separators = "=", commandDescription = "Hive (default port: " + HiveOptions.DEFAULT_PORT + + ", default host: " + HiveOptions.DEFAULT_HOST + ")") +public class HiveOptions implements DBMSSpecificOptions { + public static final String DEFAULT_HOST = "localhost"; + public static final int DEFAULT_PORT = 10000; + + @Parameter(names = "--oracle") + public List oracle = Arrays.asList(HiveOracleFactory.TLPWhere); + + public enum HiveOracleFactory implements OracleFactory { + TLPWhere { + @Override + public TestOracle create(HiveGlobalState globalState) throws SQLException { + HiveExpressionGenerator gen = new HiveExpressionGenerator(globalState); + ExpectedErrors expectedErrors = ExpectedErrors.newErrors().with(HiveErrors.getExpressionErrors()) + .build(); + + return new TLPWhereOracle<>(globalState, gen, expectedErrors); + } + }; + } + + @Override + public List getTestOracleFactory() { + return oracle; + } +} diff --git a/src/sqlancer/hive/HiveProvider.java b/src/sqlancer/hive/HiveProvider.java new file mode 100644 index 000000000..12798df93 --- /dev/null +++ b/src/sqlancer/hive/HiveProvider.java @@ -0,0 +1,119 @@ +package sqlancer.hive; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.sql.Statement; + +import com.google.auto.service.AutoService; + +import sqlancer.AbstractAction; +import sqlancer.DatabaseProvider; +import sqlancer.IgnoreMeException; +import sqlancer.MainOptions; +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.SQLProviderAdapter; +import sqlancer.StatementExecutor; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.query.SQLQueryProvider; +import sqlancer.hive.gen.HiveInsertGenerator; +import sqlancer.hive.gen.HiveTableGenerator; + +@AutoService(DatabaseProvider.class) +public class HiveProvider extends SQLProviderAdapter { + + public HiveProvider() { + super(HiveGlobalState.class, HiveOptions.class); + } + + public enum Action implements AbstractAction { + + INSERT(HiveInsertGenerator::getQuery); + + private final SQLQueryProvider sqlQueryProvider; + + Action(SQLQueryProvider sqlQueryProvider) { + this.sqlQueryProvider = sqlQueryProvider; + } + + @Override + public SQLQueryAdapter getQuery(HiveGlobalState state) throws Exception { + return sqlQueryProvider.getQuery(state); + } + } + + private static int mapActions(HiveGlobalState globalState, Action a) { + Randomly r = globalState.getRandomly(); + switch (a) { + case INSERT: + return r.getInteger(0, globalState.getOptions().getMaxNumberInserts()); + default: + throw new AssertionError(a); + } + } + + @Override + public void generateDatabase(HiveGlobalState globalState) throws Exception { + for (int i = 0; i < Randomly.fromOptions(1, 2); i++) { + boolean success; + do { + String tableName = globalState.getSchema().getFreeTableName(); + SQLQueryAdapter qt = HiveTableGenerator.generate(globalState, tableName); + success = globalState.executeStatement(qt); + } while (!success); + } + if (globalState.getSchema().getDatabaseTables().isEmpty()) { + throw new IgnoreMeException(); // TODO + } + + StatementExecutor se = new StatementExecutor<>(globalState, Action.values(), + HiveProvider::mapActions, (q) -> { + if (globalState.getSchema().getDatabaseTables().isEmpty()) { + throw new IgnoreMeException(); + } + }); + se.executeStatements(); + } + + @Override + public SQLConnection createDatabase(HiveGlobalState globalState) throws SQLException { + String username = globalState.getOptions().getUserName(); + String password = globalState.getOptions().getPassword(); + String host = globalState.getOptions().getHost(); + int port = globalState.getOptions().getPort(); + if (host == null) { + host = HiveOptions.DEFAULT_HOST; + } + if (port == MainOptions.NO_SET_PORT) { + port = HiveOptions.DEFAULT_PORT; + } + + String databaseName = globalState.getDatabaseName(); + + String url = String.format("jdbc:hive2://%s:%d/%s", host, port, "default"); + Connection con = DriverManager.getConnection(url, username, password); + globalState.getState().logStatement("DROP DATABASE IF EXISTS " + databaseName + " CASCADE"); + globalState.getState().logStatement("CREATE DATABASE " + databaseName); + globalState.getState().logStatement("USE " + databaseName); + try (Statement s = con.createStatement()) { + s.execute("DROP DATABASE IF EXISTS " + databaseName + " CASCADE"); + } + try (Statement s = con.createStatement()) { + s.execute("CREATE DATABASE " + databaseName); + } + try (Statement s = con.createStatement()) { + s.execute("USE " + databaseName); + } + con.close(); + con = DriverManager + .getConnection(String.format("jdbc:hive2://%s:%d/%s", host, port, databaseName, username, password)); + + return new SQLConnection(con); + } + + @Override + public String getDBMSName() { + return "hive"; + } +} diff --git a/src/sqlancer/hive/HiveSchema.java b/src/sqlancer/hive/HiveSchema.java new file mode 100644 index 000000000..8733d5caa --- /dev/null +++ b/src/sqlancer/hive/HiveSchema.java @@ -0,0 +1,106 @@ +package sqlancer.hive; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.common.schema.AbstractRelationalTable; +import sqlancer.common.schema.AbstractSchema; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.common.schema.AbstractTables; +import sqlancer.common.schema.TableIndex; +import sqlancer.hive.HiveSchema.HiveTable; + +public class HiveSchema extends AbstractSchema { + + public enum HiveDataType { + + // TODO: support more types, e.g. TIMESTAMP, DATE, VARCHAR, CHAR, BINARY, ARRAY, MAP, STRUCT, UNIONTYPE... + STRING, INT, DOUBLE, BOOLEAN; + + public static HiveDataType getRandomType() { + return Randomly.fromList(Arrays.asList(values())); + } + } + + public static class HiveColumn extends AbstractTableColumn { + + public HiveColumn(String name, HiveTable table, HiveDataType type) { + super(name, table, type); + } + } + + public static class HiveTables extends AbstractTables { + + public HiveTables(List tables) { + super(tables); + } + } + + public static class HiveTable extends AbstractRelationalTable { + + public HiveTable(String name, List columns, boolean isView) { + super(name, columns, Collections.emptyList(), isView); + } + } + + public HiveSchema(List databaseTables) { + super(databaseTables); + } + + public static HiveSchema fromConnection(SQLConnection con, String databaseName) throws SQLException { + List databaseTables = new ArrayList<>(); + List tableNames = getTableNames(con); + for (String tableName : tableNames) { + List databaseColumns = getTableColumns(con, tableName); + boolean isView = matchesViewName(tableName); + HiveTable t = new HiveTable(tableName, databaseColumns, isView); + for (HiveColumn c : databaseColumns) { + c.setTable(t); + } + databaseTables.add(t); + } + return new HiveSchema(databaseTables); + } + + private static List getTableNames(SQLConnection con) throws SQLException { + List tableNames = new ArrayList<>(); + try (Statement s = con.createStatement()) { + ResultSet tableRs = s.executeQuery("SHOW TABLES"); + while (tableRs.next()) { + String tableName = tableRs.getString(1); + tableNames.add(tableName); + } + } + return tableNames; + } + + private static List getTableColumns(SQLConnection con, String tableName) throws SQLException { + List columns = new ArrayList<>(); + try (Statement s = con.createStatement()) { + try (ResultSet rs = s.executeQuery(String.format("DESCRIBE %s", tableName))) { + while (rs.next()) { + String columnName = rs.getString("col_name"); + String dataType = rs.getString("data_type"); + HiveColumn c = new HiveColumn(columnName, null, getColumnType(dataType.toUpperCase())); + columns.add(c); + } + } + } + return columns; + } + + private static HiveDataType getColumnType(String typeString) { + return HiveDataType.valueOf(typeString.toUpperCase()); + } + + public HiveTables getRandomTableNonEmptyTables() { + return new HiveTables(Randomly.nonEmptySubset(getDatabaseTables())); + } +} diff --git a/src/sqlancer/hive/HiveToStringVisitor.java b/src/sqlancer/hive/HiveToStringVisitor.java new file mode 100644 index 000000000..bdcd31eaf --- /dev/null +++ b/src/sqlancer/hive/HiveToStringVisitor.java @@ -0,0 +1,115 @@ +package sqlancer.hive; + +import sqlancer.common.ast.newast.NewToStringVisitor; +import sqlancer.common.ast.newast.TableReferenceNode; +import sqlancer.hive.ast.HiveCastOperation; +import sqlancer.hive.ast.HiveConstant; +import sqlancer.hive.ast.HiveExpression; +import sqlancer.hive.ast.HiveJoin; +import sqlancer.hive.ast.HiveSelect; + +public class HiveToStringVisitor extends NewToStringVisitor { + + @Override + public void visitSpecific(HiveExpression expr) { + if (expr instanceof HiveConstant) { + visit((HiveConstant) expr); + } else if (expr instanceof HiveSelect) { + visit((HiveSelect) expr); + } else if (expr instanceof HiveJoin) { + visit((HiveJoin) expr); + } else if (expr instanceof HiveCastOperation) { + visit((HiveCastOperation) expr); + } else { + throw new AssertionError(expr.getClass()); + } + } + + private void visit(HiveConstant constant) { + sb.append(constant.toString()); + } + + private void visit(HiveSelect select) { + sb.append("SELECT "); + if (select.isDistinct()) { + sb.append("DISTINCT "); + } + visit(select.getFetchColumns()); + sb.append(" FROM "); + visit(select.getFromList()); + if (!select.getFromList().isEmpty() && !select.getJoinList().isEmpty()) { + sb.append(", "); + } + if (!select.getJoinList().isEmpty()) { + visit(select.getJoinList()); + } + if (select.getWhereClause() != null) { + sb.append(" WHERE "); + visit(select.getWhereClause()); + } + if (!select.getGroupByExpressions().isEmpty()) { + sb.append(" GROUP BY "); + visit(select.getGroupByExpressions()); + } + if (select.getHavingClause() != null) { + sb.append(" HAVING "); + visit(select.getHavingClause()); + } + if (!select.getOrderByClauses().isEmpty()) { + sb.append(" ORDER BY "); + visit(select.getOrderByClauses()); + } + if (select.getLimitClause() != null) { + sb.append(" LIMIT "); + visit(select.getLimitClause()); + } + if (select.getOffsetClause() != null) { + sb.append(" OFFSET "); + visit(select.getOffsetClause()); + } + } + + private void visit(HiveJoin join) { + switch (join.getJoinType()) { + case INNER: + sb.append(" INNER JOIN "); + break; + case LEFT_OUTER: + sb.append(" LEFT JOIN "); + break; + case RIGHT_OUTER: + sb.append(" RIGHT JOIN "); + break; + case FULL_OUTER: + sb.append(" FULL JOIN "); + break; + case LEFT_SEMI: + sb.append(" LEFT SEMI JOIN "); + break; + case CROSS: + sb.append(" CROSS JOIN "); + break; + default: + throw new UnsupportedOperationException(); + } + visit((TableReferenceNode) join.getRightTable()); + if (join.getOnClause() != null) { + sb.append(" ON "); + visit(join.getOnClause()); + } + } + + private void visit(HiveCastOperation cast) { + sb.append("CAST("); + visit(cast.getExpression()); + sb.append(" AS "); + sb.append(cast.getType()); + sb.append(")"); + } + + public static String asString(HiveExpression expr) { + HiveToStringVisitor visitor = new HiveToStringVisitor(); + visitor.visit(expr); + return visitor.get(); + } +} diff --git a/src/sqlancer/hive/ast/HiveBetweenOperation.java b/src/sqlancer/hive/ast/HiveBetweenOperation.java new file mode 100644 index 000000000..26ec1d940 --- /dev/null +++ b/src/sqlancer/hive/ast/HiveBetweenOperation.java @@ -0,0 +1,10 @@ +package sqlancer.hive.ast; + +import sqlancer.common.ast.newast.NewBetweenOperatorNode; + +public class HiveBetweenOperation extends NewBetweenOperatorNode implements HiveExpression { + + public HiveBetweenOperation(HiveExpression left, HiveExpression middle, HiveExpression right, boolean isTrue) { + super(left, middle, right, isTrue); + } +} diff --git a/src/sqlancer/hive/ast/HiveBinaryOperation.java b/src/sqlancer/hive/ast/HiveBinaryOperation.java new file mode 100644 index 000000000..f74d117d4 --- /dev/null +++ b/src/sqlancer/hive/ast/HiveBinaryOperation.java @@ -0,0 +1,11 @@ +package sqlancer.hive.ast; + +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.common.ast.newast.NewBinaryOperatorNode; + +public class HiveBinaryOperation extends NewBinaryOperatorNode implements HiveExpression { + + public HiveBinaryOperation(HiveExpression left, HiveExpression right, Operator op) { + super(left, right, op); + } +} diff --git a/src/sqlancer/hive/ast/HiveCaseOperation.java b/src/sqlancer/hive/ast/HiveCaseOperation.java new file mode 100644 index 000000000..666070667 --- /dev/null +++ b/src/sqlancer/hive/ast/HiveCaseOperation.java @@ -0,0 +1,13 @@ +package sqlancer.hive.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewCaseOperatorNode; + +public class HiveCaseOperation extends NewCaseOperatorNode implements HiveExpression { + + public HiveCaseOperation(HiveExpression switchCondition, List conditions, + List expressions, HiveExpression elseExpr) { + super(switchCondition, conditions, expressions, elseExpr); + } +} diff --git a/src/sqlancer/hive/ast/HiveCastOperation.java b/src/sqlancer/hive/ast/HiveCastOperation.java new file mode 100644 index 000000000..2d76ab4f2 --- /dev/null +++ b/src/sqlancer/hive/ast/HiveCastOperation.java @@ -0,0 +1,25 @@ +package sqlancer.hive.ast; + +import sqlancer.hive.HiveSchema.HiveDataType; + +public class HiveCastOperation implements HiveExpression { + + private final HiveExpression expression; + private final HiveDataType type; + + public HiveCastOperation(HiveExpression expression, HiveDataType type) { + if (expression == null) { + throw new AssertionError(); + } + this.expression = expression; + this.type = type; + } + + public HiveExpression getExpression() { + return expression; + } + + public HiveDataType getType() { + return type; + } +} diff --git a/src/sqlancer/hive/ast/HiveColumnReference.java b/src/sqlancer/hive/ast/HiveColumnReference.java new file mode 100644 index 000000000..c3237955f --- /dev/null +++ b/src/sqlancer/hive/ast/HiveColumnReference.java @@ -0,0 +1,11 @@ +package sqlancer.hive.ast; + +import sqlancer.common.ast.newast.ColumnReferenceNode; +import sqlancer.hive.HiveSchema.HiveColumn; + +public class HiveColumnReference extends ColumnReferenceNode implements HiveExpression { + + public HiveColumnReference(HiveColumn column) { + super(column); + } +} diff --git a/src/sqlancer/hive/ast/HiveConstant.java b/src/sqlancer/hive/ast/HiveConstant.java new file mode 100644 index 000000000..7f89997d2 --- /dev/null +++ b/src/sqlancer/hive/ast/HiveConstant.java @@ -0,0 +1,192 @@ +package sqlancer.hive.ast; + +import java.math.BigDecimal; +import java.sql.Timestamp; +import java.text.SimpleDateFormat; + +public abstract class HiveConstant implements HiveExpression { + + public boolean isNull() { + return false; + } + + public static class HiveNullConstant extends HiveConstant { + + @Override + public boolean isNull() { + return true; + } + + @Override + public String toString() { + return "NULL"; + } + } + + public static class HiveIntConstant extends HiveConstant { + + private final long value; + + public HiveIntConstant(long value) { + this.value = value; + } + + public long getValue() { + return value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + } + + public static class HiveDoubleConstant extends HiveConstant { + + private final double value; + + public HiveDoubleConstant(double value) { + this.value = value; + } + + public double getValue() { + return value; + } + + @Override + public String toString() { + if (value == Double.POSITIVE_INFINITY) { + return "'+Inf'"; + } else if (value == Double.NEGATIVE_INFINITY) { + return "'-Inf'"; + } + return String.valueOf(value); + } + } + + public static class HiveDecimalConstant extends HiveConstant { + + private final BigDecimal value; + + public HiveDecimalConstant(BigDecimal value) { + this.value = value; + } + + public BigDecimal getValue() { + return value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + } + + public static class HiveTimestampConstant extends HiveConstant { + + private final String textRepr; + + public HiveTimestampConstant(long value) { + Timestamp timestamp = new Timestamp(value); + SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd"); + this.textRepr = dateFormat.format(timestamp); + } + + public String getValue() { + return textRepr; + } + + @Override + public String toString() { + return String.format("TIMESTAMP '%s'", textRepr); + } + } + + public static class HiveDateConstant extends HiveConstant { + + private final String textRepr; + + public HiveDateConstant(long value) { + Timestamp timestamp = new Timestamp(value); + SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd"); + this.textRepr = dateFormat.format(timestamp); + } + + public String getValue() { + return textRepr; + } + + @Override + public String toString() { + return String.format("DATE '%s'", textRepr); + } + } + + public static class StringConstant extends HiveConstant { + + private final String value; + + public StringConstant(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + + @Override + public String toString() { + return "'" + value.replace("'", "''") + "'"; + } + } + + public static class HiveBooleanConstant extends HiveConstant { + + private final boolean value; + + public HiveBooleanConstant(boolean value) { + this.value = value; + } + + public boolean getValue() { + return value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + } + + public static HiveConstant createNullConstant() { + return new HiveNullConstant(); + } + + public static HiveConstant createIntConstant(long value) { + return new HiveIntConstant(value); + } + + public static HiveConstant createDoubleConstant(double value) { + return new HiveDoubleConstant(value); + } + + public static HiveConstant createDecimalConstant(BigDecimal value) { + return new HiveDecimalConstant(value); + } + + public static HiveConstant createTimestampConstant(long value) { + return new HiveTimestampConstant(value); + } + + public static HiveConstant createDateConstant(long value) { + return new HiveDateConstant(value); + } + + public static HiveConstant createStringConstant(String value) { + return new StringConstant(value); + } + + public static HiveConstant createBooleanConstant(boolean value) { + return new HiveBooleanConstant(value); + } +} diff --git a/src/sqlancer/hive/ast/HiveExpression.java b/src/sqlancer/hive/ast/HiveExpression.java new file mode 100644 index 000000000..40842e181 --- /dev/null +++ b/src/sqlancer/hive/ast/HiveExpression.java @@ -0,0 +1,7 @@ +package sqlancer.hive.ast; + +import sqlancer.common.ast.newast.Expression; +import sqlancer.hive.HiveSchema.HiveColumn; + +public interface HiveExpression extends Expression { +} diff --git a/src/sqlancer/hive/ast/HiveFunction.java b/src/sqlancer/hive/ast/HiveFunction.java new file mode 100644 index 000000000..b3a34ae25 --- /dev/null +++ b/src/sqlancer/hive/ast/HiveFunction.java @@ -0,0 +1,13 @@ +package sqlancer.hive.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewFunctionNode; + +public class HiveFunction extends NewFunctionNode implements HiveExpression { + + public HiveFunction(List args, F func) { + super(args, func); + } + +} diff --git a/src/sqlancer/hive/ast/HiveInOperation.java b/src/sqlancer/hive/ast/HiveInOperation.java new file mode 100644 index 000000000..601bf5e19 --- /dev/null +++ b/src/sqlancer/hive/ast/HiveInOperation.java @@ -0,0 +1,12 @@ +package sqlancer.hive.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewInOperatorNode; + +public class HiveInOperation extends NewInOperatorNode implements HiveExpression { + + public HiveInOperation(HiveExpression left, List right, boolean isNegated) { + super(left, right, isNegated); + } +} diff --git a/src/sqlancer/hive/ast/HiveJoin.java b/src/sqlancer/hive/ast/HiveJoin.java new file mode 100644 index 000000000..932ed9afa --- /dev/null +++ b/src/sqlancer/hive/ast/HiveJoin.java @@ -0,0 +1,48 @@ +package sqlancer.hive.ast; + +import sqlancer.common.ast.newast.Join; +import sqlancer.hive.HiveSchema.HiveColumn; +import sqlancer.hive.HiveSchema.HiveTable; + +public class HiveJoin implements HiveExpression, Join { + + private final HiveTableReference leftTable; + private final HiveTableReference rightTable; + private final JoinType joinType; + private HiveExpression onClause; + + // TODO: test map-join optimization + + public enum JoinType { + INNER, LEFT_OUTER, RIGHT_OUTER, FULL_OUTER, LEFT_SEMI, CROSS; + } + + public HiveJoin(HiveTableReference leftTable, HiveTableReference rightTable, JoinType joinType, + HiveExpression onClause) { + this.leftTable = leftTable; + this.rightTable = rightTable; + this.joinType = joinType; + this.onClause = onClause; + } + + public HiveTableReference getLeftTable() { + return leftTable; + } + + public HiveTableReference getRightTable() { + return rightTable; + } + + public JoinType getJoinType() { + return joinType; + } + + public HiveExpression getOnClause() { + return onClause; + } + + @Override + public void setOnClause(HiveExpression onClause) { + this.onClause = onClause; + } +} diff --git a/src/sqlancer/hive/ast/HiveOrderingTerm.java b/src/sqlancer/hive/ast/HiveOrderingTerm.java new file mode 100644 index 000000000..70fef52ad --- /dev/null +++ b/src/sqlancer/hive/ast/HiveOrderingTerm.java @@ -0,0 +1,10 @@ +package sqlancer.hive.ast; + +import sqlancer.common.ast.newast.NewOrderingTerm; + +public class HiveOrderingTerm extends NewOrderingTerm implements HiveExpression { + + public HiveOrderingTerm(HiveExpression expr, Ordering ordering) { + super(expr, ordering); + } +} diff --git a/src/sqlancer/hive/ast/HiveSelect.java b/src/sqlancer/hive/ast/HiveSelect.java new file mode 100644 index 000000000..8a0eddc06 --- /dev/null +++ b/src/sqlancer/hive/ast/HiveSelect.java @@ -0,0 +1,41 @@ +package sqlancer.hive.ast; + +import java.util.List; + +import sqlancer.common.ast.SelectBase; +import sqlancer.common.ast.newast.Select; +import sqlancer.hive.HiveSchema.HiveColumn; +import sqlancer.hive.HiveSchema.HiveTable; +import sqlancer.hive.HiveToStringVisitor; + +public class HiveSelect extends SelectBase + implements Select, HiveExpression { + + private boolean isDistinct; + + public void setDistinct(boolean isDistinct) { + this.isDistinct = isDistinct; + } + + public boolean isDistinct() { + return isDistinct; + } + + @Override + public void setJoinClauses(List joinStatements) { + List expressions = joinStatements.stream().map(e -> (HiveExpression) e) + .collect(java.util.stream.Collectors.toList()); + setJoinList(expressions); + } + + @Override + public List getJoinClauses() { + return getJoinList().stream().map(e -> (HiveJoin) e).collect(java.util.stream.Collectors.toList()); + } + + @Override + public String asString() { + return HiveToStringVisitor.asString(this); + } + +} diff --git a/src/sqlancer/hive/ast/HiveTableReference.java b/src/sqlancer/hive/ast/HiveTableReference.java new file mode 100644 index 000000000..7d23b4895 --- /dev/null +++ b/src/sqlancer/hive/ast/HiveTableReference.java @@ -0,0 +1,13 @@ +package sqlancer.hive.ast; + +import sqlancer.common.ast.newast.TableReferenceNode; +import sqlancer.hive.HiveSchema; + +public class HiveTableReference extends TableReferenceNode + implements HiveExpression { + + public HiveTableReference(HiveSchema.HiveTable table) { + super(table); + } + +} diff --git a/src/sqlancer/hive/ast/HiveUnaryPostfixOperation.java b/src/sqlancer/hive/ast/HiveUnaryPostfixOperation.java new file mode 100644 index 000000000..0461c5c73 --- /dev/null +++ b/src/sqlancer/hive/ast/HiveUnaryPostfixOperation.java @@ -0,0 +1,12 @@ +package sqlancer.hive.ast; + +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.common.ast.newast.NewUnaryPostfixOperatorNode; + +public class HiveUnaryPostfixOperation extends NewUnaryPostfixOperatorNode implements HiveExpression { + + public HiveUnaryPostfixOperation(HiveExpression expr, Operator op) { + super(expr, op); + } + +} diff --git a/src/sqlancer/hive/ast/HiveUnaryPrefixOperation.java b/src/sqlancer/hive/ast/HiveUnaryPrefixOperation.java new file mode 100644 index 000000000..9fe286f82 --- /dev/null +++ b/src/sqlancer/hive/ast/HiveUnaryPrefixOperation.java @@ -0,0 +1,12 @@ +package sqlancer.hive.ast; + +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode; + +public class HiveUnaryPrefixOperation extends NewUnaryPrefixOperatorNode implements HiveExpression { + + public HiveUnaryPrefixOperation(HiveExpression expr, Operator op) { + super(expr, op); + } + +} diff --git a/src/sqlancer/hive/gen/HiveExpressionGenerator.java b/src/sqlancer/hive/gen/HiveExpressionGenerator.java new file mode 100644 index 000000000..92154873c --- /dev/null +++ b/src/sqlancer/hive/gen/HiveExpressionGenerator.java @@ -0,0 +1,369 @@ +package sqlancer.hive.gen; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.common.ast.newast.NewOrderingTerm.Ordering; +import sqlancer.common.gen.TLPWhereGenerator; +import sqlancer.common.gen.UntypedExpressionGenerator; +import sqlancer.common.schema.AbstractTables; +import sqlancer.hive.HiveBugs; +import sqlancer.hive.HiveGlobalState; +import sqlancer.hive.HiveSchema.HiveColumn; +import sqlancer.hive.HiveSchema.HiveDataType; +import sqlancer.hive.HiveSchema.HiveTable; +import sqlancer.hive.ast.HiveBetweenOperation; +import sqlancer.hive.ast.HiveBinaryOperation; +import sqlancer.hive.ast.HiveCaseOperation; +import sqlancer.hive.ast.HiveCastOperation; +import sqlancer.hive.ast.HiveColumnReference; +import sqlancer.hive.ast.HiveConstant; +import sqlancer.hive.ast.HiveExpression; +import sqlancer.hive.ast.HiveFunction; +import sqlancer.hive.ast.HiveInOperation; +import sqlancer.hive.ast.HiveJoin; +import sqlancer.hive.ast.HiveOrderingTerm; +import sqlancer.hive.ast.HiveSelect; +import sqlancer.hive.ast.HiveTableReference; +import sqlancer.hive.ast.HiveUnaryPostfixOperation; +import sqlancer.hive.ast.HiveUnaryPrefixOperation; + +public class HiveExpressionGenerator extends UntypedExpressionGenerator + implements TLPWhereGenerator { + + private final HiveGlobalState globalState; + private List tables; + + private enum Expression { + // TODO: add or delete expressions. + UNARY_PREFIX, UNARY_POSTFIX, BINARY_COMPARISON, BINARY_LOGICAL, BINARY_ARITHMETIC, CAST, FUNC, BETWEEN, IN, + CASE; + } + + public HiveExpressionGenerator(HiveGlobalState globalState) { + this.globalState = globalState; + } + + @Override + public HiveExpression negatePredicate(HiveExpression predicate) { + return new HiveUnaryPrefixOperation(predicate, HiveUnaryPrefixOperator.NOT); + } + + @Override + public HiveExpression isNull(HiveExpression expr) { + return new HiveUnaryPostfixOperation(expr, HiveUnaryPostfixOperator.IS_NULL); + } + + @Override + protected HiveExpression generateExpression(int depth) { + // TODO: randomly cast some types like what PostgresExpressionGenerator does? + return generateExpressionInternal(depth); + } + + private HiveExpression generateExpressionInternal(int depth) throws AssertionError { + if (depth >= globalState.getOptions().getMaxExpressionDepth() + || Randomly.getBooleanWithRatherLowProbability()) { + return generateLeafNode(); + } + if (allowAggregates && Randomly.getBooleanWithRatherLowProbability()) { + allowAggregates = false; // aggregate function calls cannot be nested + HiveAggregateFunction aggregate = HiveAggregateFunction.getRandom(); + return new HiveFunction<>(generateExpressions(aggregate.getNrArgs(), depth + 1), aggregate); + } + + List possibleOptions = new ArrayList<>(Arrays.asList(Expression.values())); + if (HiveBugs.bugNonBooleanWhereClause) { + possibleOptions.remove(Expression.CAST); + possibleOptions.remove(Expression.FUNC); + possibleOptions.remove(Expression.BINARY_ARITHMETIC); + } + if (HiveBugs.bugInBooleanEvaluation) { + possibleOptions.remove(Expression.IN); + } + if (HiveBugs.bugBetweenMixedTypes) { + possibleOptions.remove(Expression.BETWEEN); + } + + Expression expr = Randomly.fromList(possibleOptions); + switch (expr) { + case UNARY_PREFIX: + HiveUnaryPrefixOperator prefixOp = HiveUnaryPrefixOperator.getRandom(); + if (HiveBugs.bugNegationNullability + && (prefixOp == HiveUnaryPrefixOperator.MINUS || prefixOp == HiveUnaryPrefixOperator.PLUS)) { + throw new IgnoreMeException(); + } + return new HiveUnaryPrefixOperation(generateExpression(depth + 1), prefixOp); + case UNARY_POSTFIX: + return new HiveUnaryPostfixOperation(generateExpression(depth + 1), HiveUnaryPostfixOperator.getRandom()); + case BINARY_COMPARISON: + Operator op = HiveBinaryComparisonOperator.getRandom(); + return new HiveBinaryOperation(generateExpression(depth + 1), generateExpression(depth + 1), op); + case BINARY_LOGICAL: + op = HiveExpressionGenerator.HiveBinaryLogicalOperator.getRandom(); + return new HiveBinaryOperation(generateExpression(depth + 1), generateExpression(depth + 1), op); + case BINARY_ARITHMETIC: + return new HiveBinaryOperation(generateExpression(depth + 1), generateExpression(depth + 1), + HiveExpressionGenerator.HiveBinaryArithmeticOperator.getRandom()); + case CAST: + return new HiveCastOperation(generateExpression(depth + 1), HiveDataType.getRandomType()); + case FUNC: + HiveFunc func = HiveFunc.getRandom(); + return new HiveFunction<>(generateExpressions(func.getNrArgs()), func); + case BETWEEN: + return new HiveBetweenOperation(generateExpression(depth + 1), generateExpression(depth + 1), + generateExpression(depth + 1), Randomly.getBoolean()); + case IN: + return new HiveInOperation(generateExpression(depth + 1), + generateExpressions(Randomly.smallNumber() + 1, depth + 1), Randomly.getBoolean()); + case CASE: + int nr = Randomly.smallNumber() + 1; + return new HiveCaseOperation(generateExpression(depth + 1), generateExpressions(nr, depth + 1), + generateExpressions(nr, depth + 1), generateExpression(depth + 1)); + default: + throw new AssertionError(expr); + } + } + + @Override + public HiveExpression generateConstant() { + if (Randomly.getBooleanWithRatherLowProbability()) { + return HiveConstant.createNullConstant(); + } + HiveDataType[] values = HiveDataType.values(); + HiveDataType constantType = Randomly.fromOptions(values); + switch (constantType) { + case STRING: + return HiveConstant.createStringConstant(globalState.getRandomly().getString()); + case INT: + return HiveConstant.createIntConstant(globalState.getRandomly().getInteger()); + case DOUBLE: + return HiveConstant.createDoubleConstant(globalState.getRandomly().getDouble()); + case BOOLEAN: + return HiveConstant.createBooleanConstant(Randomly.getBoolean()); + default: + throw new AssertionError(constantType); + } + } + + @Override + protected HiveExpression generateColumn() { + HiveColumn column = Randomly.fromList(columns); + return new HiveColumnReference(column); + } + + @Override + public List generateOrderBys() { + List expr = super.generateOrderBys(); + List newExpr = new ArrayList<>(expr.size()); + for (HiveExpression curExpr : expr) { + if (Randomly.getBoolean()) { + curExpr = new HiveOrderingTerm(curExpr, Ordering.getRandom()); + } + newExpr.add(curExpr); + } + return newExpr; + } + + @Override + public HiveExpressionGenerator setTablesAndColumns(AbstractTables tables) { + this.columns = tables.getColumns(); + this.tables = tables.getTables(); + + return this; + } + + @Override + public HiveExpression generateBooleanExpression() { + return generateExpression(); + } + + @Override + public HiveSelect generateSelect() { + return new HiveSelect(); + } + + @Override + public List getTableRefs() { + return tables.stream().map(t -> new HiveTableReference(t)).collect(Collectors.toList()); + } + + @Override + public List generateFetchColumns(boolean allowAggregates) { + if (Randomly.getBoolean()) { + return List.of(new HiveColumnReference(new HiveColumn("*", null, null))); + } + return Randomly.nonEmptySubset(columns).stream().map(c -> new HiveColumnReference(c)) + .collect(Collectors.toList()); + } + + @Override + public List getRandomJoinClauses() { + return List.of(); + } + + public enum HiveUnaryPrefixOperator implements Operator { + + // TODO: ~A (bitwise NOT) + NOT("NOT"), PLUS("+"), MINUS("-"); + + private String textRepr; + + HiveUnaryPrefixOperator(String textRepr) { + this.textRepr = textRepr; + } + + public static HiveUnaryPrefixOperator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepr; + } + } + + public enum HiveUnaryPostfixOperator implements Operator { + + // TODO: A IS [NOT] (NULL|TRUE|FALSE)... + IS_NULL("IS NULL"), IS_NOT_NULL("IS NOT NULL"); + + private String textRepr; + + HiveUnaryPostfixOperator(String textRepr) { + this.textRepr = textRepr; + } + + public static HiveUnaryPostfixOperator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepr; + } + } + + public enum HiveBinaryComparisonOperator implements Operator { + + EQUALS("="), GREATER(">"), GREATER_EQUALS(">="), SMALLER("<"), SMALLER_EQUALS("<="), NOT_EQUALS("!="), + LIKE("LIKE"), NOT_LIKE("NOT LIKE"), REGEXP("RLIKE"); + + private String textRepr; + + HiveBinaryComparisonOperator(String textRepr) { + this.textRepr = textRepr; + } + + public static HiveBinaryComparisonOperator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepr; + } + } + + public enum HiveBinaryLogicalOperator implements Operator { + + AND("AND"), OR("OR"); + + private String textRepr; + + HiveBinaryLogicalOperator(String textRepr) { + this.textRepr = textRepr; + } + + public static HiveBinaryLogicalOperator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepr; + } + } + + public enum HiveBinaryArithmeticOperator implements Operator { + + CONCAT("||"), ADD("+"), SUB("-"), MULT("*"), DIV("/"), MOD("%"), BITWISE_AND("&"), BITWISE_OR("|"), + BITWISE_XOR("^"); + + private String textRepr; + + HiveBinaryArithmeticOperator(String textRepr) { + this.textRepr = textRepr; + } + + public static HiveBinaryArithmeticOperator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepr; + } + } + + public enum HiveAggregateFunction { + COUNT(1), SUM(1), AVG(1), MIN(1), MAX(1), VARIANCE(1), VAR_SAMP(1), STDDEV_POP(1), STDDEV_SAMP(1), COVAR_POP(2), + COVAR_SAMP(2), CORR(2); + + private int nrArgs; + + HiveAggregateFunction(int nrArgs) { + this.nrArgs = nrArgs; + } + + public static HiveAggregateFunction getRandom() { + return Randomly.fromOptions(values()); + } + + public int getNrArgs() { + return nrArgs; + } + } + + // TODO: test all Hive default functions... + public enum HiveFunc { + + // mathematical functions + ROUND(2), FLOOR(1); + + // collection functions + + // date functions + + // string functions + + private int nrArgs; + private boolean isVariadic; + + HiveFunc(int nrArgs) { + this(nrArgs, false); + } + + HiveFunc(int nrArgs, boolean isVariadic) { + this.nrArgs = nrArgs; + this.isVariadic = isVariadic; + } + + public static HiveFunc getRandom() { + return Randomly.fromOptions(values()); + } + + public int getNrArgs() { + if (isVariadic) { + return Randomly.smallNumber() + nrArgs; + } else { + return nrArgs; + } + } + + } +} diff --git a/src/sqlancer/hive/gen/HiveInsertGenerator.java b/src/sqlancer/hive/gen/HiveInsertGenerator.java new file mode 100644 index 000000000..963fafbce --- /dev/null +++ b/src/sqlancer/hive/gen/HiveInsertGenerator.java @@ -0,0 +1,52 @@ +package sqlancer.hive.gen; + +import java.util.List; + +import sqlancer.common.gen.AbstractInsertGenerator; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.hive.HiveErrors; +import sqlancer.hive.HiveGlobalState; +import sqlancer.hive.HiveSchema.HiveColumn; +import sqlancer.hive.HiveSchema.HiveTable; +import sqlancer.hive.HiveToStringVisitor; + +public class HiveInsertGenerator extends AbstractInsertGenerator { + + private final HiveGlobalState globalState; + private final ExpectedErrors errors = new ExpectedErrors(); + private final HiveExpressionGenerator gen; + + public HiveInsertGenerator(HiveGlobalState globalState) { + this.globalState = globalState; + this.gen = new HiveExpressionGenerator(globalState); + } + + public static SQLQueryAdapter getQuery(HiveGlobalState globalState) { + return new HiveInsertGenerator(globalState).generate(); + } + + @Override + protected void insertValue(HiveColumn column) { + sb.append(HiveToStringVisitor.asString(gen.generateConstant())); + } + + private SQLQueryAdapter generate() { + // Inserting values into tables from SQL. + sb.append("INSERT INTO "); + HiveTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + sb.append(table.getName()); + + // TODO: specify the inserted partition + + sb.append(" VALUES "); + + // Values must be provided by every column in the Hive table. + // A value is either null or any valid SQL literal. + List columns = table.getColumns(); + insertColumns(columns); + + HiveErrors.addInsertErrors(errors); + return new SQLQueryAdapter(sb.toString(), errors, false, false); + } +} diff --git a/src/sqlancer/hive/gen/HiveTableGenerator.java b/src/sqlancer/hive/gen/HiveTableGenerator.java new file mode 100644 index 000000000..c1c4db2bd --- /dev/null +++ b/src/sqlancer/hive/gen/HiveTableGenerator.java @@ -0,0 +1,122 @@ +package sqlancer.hive.gen; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.DBMSCommon; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.hive.HiveErrors; +import sqlancer.hive.HiveGlobalState; +import sqlancer.hive.HiveSchema; +import sqlancer.hive.HiveSchema.HiveColumn; +import sqlancer.hive.HiveSchema.HiveDataType; +import sqlancer.hive.HiveSchema.HiveTable; +import sqlancer.hive.HiveToStringVisitor; + +public class HiveTableGenerator { + + // TODO: support various file formats? e.g. JSONFILE, SEQUENCEFILE, TEXTFILE, RCFILE, ORC, PARQUET, AVRO. + + private enum ColumnConstraints { + PRIMARY_KEY_DISABLE, UNIQUE_DISABLE, NOT_NULL, DEFAULT, CHECK + // ENABLE_OR_DISABLE, NOVALIDATE, RELY_OR_NORELY + } + + private final HiveGlobalState globalState; + private final String tableName; + private final boolean allowPrimaryKey = Randomly.getBoolean(); + private final StringBuilder sb = new StringBuilder(); + private final HiveExpressionGenerator gen; + private final HiveTable table; + private final List columnsToBeAdded = new ArrayList<>(); + private boolean setPrimaryKey; + + public HiveTableGenerator(HiveGlobalState globalState, String tableName) { + this.tableName = tableName; + this.globalState = globalState; + this.table = new HiveTable(tableName, columnsToBeAdded, false); + this.gen = new HiveExpressionGenerator(globalState).setColumns(columnsToBeAdded); + } + + public static SQLQueryAdapter generate(HiveGlobalState globalState, String tableName) { + HiveTableGenerator generator = new HiveTableGenerator(globalState, tableName); + return generator.create(); + } + + private SQLQueryAdapter create() { + ExpectedErrors errors = new ExpectedErrors(); + + sb.append("CREATE TABLE "); + sb.append(globalState.getDatabaseName()); + sb.append("."); + sb.append(tableName); + sb.append(" ("); + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + if (i != 0) { + sb.append(", "); + } + appendColumn(i); + } + sb.append(")"); + + // TODO: implement PARTITION BY clause + // TODO: implement CLUSTERED BY, SKEWED BY clauses + // TODO: implement ROW FORMAT and STORED AS clauses + // TODO: randomly add some predefined TABLEPROPERTIES + // TODO: implement CTAS (AS clause) + + HiveErrors.addExpressionErrors(errors); + return new SQLQueryAdapter(sb.toString(), errors, true, false); + } + + private void appendColumn(int columnId) { + String columnName = DBMSCommon.createColumnName(columnId); + sb.append(columnName); + sb.append(" "); + HiveDataType randType = HiveSchema.HiveDataType.getRandomType(); + sb.append(randType); + columnsToBeAdded.add(new HiveColumn(columnName, table, randType)); + appendColumnConstraint(); + } + + private void appendColumnConstraint() { + /* + * column_constraint_specification: : [ PRIMARY KEY|UNIQUE|NOT NULL|DEFAULT [default_value]|CHECK + * [check_expression] ENABLE|DISABLE NOVALIDATE RELY/NORELY ] + */ + if (Randomly.getBoolean()) { + // no column constraint + return; + } + + ColumnConstraints constraint = Randomly.fromOptions(ColumnConstraints.values()); + switch (constraint) { + case PRIMARY_KEY_DISABLE: + if (allowPrimaryKey && !setPrimaryKey) { + sb.append(" PRIMARY KEY DISABLE"); + setPrimaryKey = true; + } + break; + case UNIQUE_DISABLE: + sb.append(" UNIQUE DISABLE"); + break; + case NOT_NULL: + sb.append(" NOT NULL"); + break; + case DEFAULT: + sb.append(" DEFAULT ("); + sb.append(HiveToStringVisitor.asString(gen.generateConstant())); + sb.append(")"); + case CHECK: + sb.append(" CHECK ("); + sb.append(HiveToStringVisitor.asString(gen.generateExpression())); + sb.append(")"); + break; + default: + throw new AssertionError(constraint); + } + } + +} diff --git a/src/sqlancer/hsqldb/HSQLDBErrors.java b/src/sqlancer/hsqldb/HSQLDBErrors.java new file mode 100644 index 000000000..33995a0e1 --- /dev/null +++ b/src/sqlancer/hsqldb/HSQLDBErrors.java @@ -0,0 +1,43 @@ +package sqlancer.hsqldb; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.common.query.ExpectedErrors; + +public final class HSQLDBErrors { + + private HSQLDBErrors() { + } + + public static List getExpressionErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("invalid datetime format"); + errors.add("invalid character value for cast"); + errors.add("invalid ORDER BY expression"); + errors.add("data type of expression is not boolean"); + errors.add("numeric value out of range"); + errors.add("incompatible data types in combination"); + errors.add("string data, right truncation"); + + return errors; + } + + public static void addExpressionErrors(ExpectedErrors errors) { + errors.addAll(getExpressionErrors()); + } + + public static List getInsertErrors() { + ArrayList errors = new ArrayList<>(); + + errors.addAll(getExpressionErrors()); + + return errors; + } + + public static void addInsertErrors(ExpectedErrors errors) { + errors.addAll(getInsertErrors()); + } + +} diff --git a/src/sqlancer/hsqldb/HSQLDBOptions.java b/src/sqlancer/hsqldb/HSQLDBOptions.java new file mode 100644 index 000000000..215f4f0fc --- /dev/null +++ b/src/sqlancer/hsqldb/HSQLDBOptions.java @@ -0,0 +1,21 @@ +package sqlancer.hsqldb; + +import java.util.List; + +import com.beust.jcommander.Parameter; +import com.beust.jcommander.Parameters; + +import sqlancer.DBMSSpecificOptions; + +@Parameters(commandDescription = "hsqldb") +public class HSQLDBOptions implements DBMSSpecificOptions { + + @Parameter(names = "--oracle") + public List oracle = List.of(HSQLDBOracleFactory.WHERE, HSQLDBOracleFactory.NOREC); + + @Override + public List getTestOracleFactory() { + return oracle; + } + +} diff --git a/src/sqlancer/hsqldb/HSQLDBOracleFactory.java b/src/sqlancer/hsqldb/HSQLDBOracleFactory.java new file mode 100644 index 000000000..ed4fdf79c --- /dev/null +++ b/src/sqlancer/hsqldb/HSQLDBOracleFactory.java @@ -0,0 +1,32 @@ +package sqlancer.hsqldb; + +import java.sql.SQLException; + +import sqlancer.OracleFactory; +import sqlancer.common.oracle.NoRECOracle; +import sqlancer.common.oracle.TLPWhereOracle; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.hsqldb.gen.HSQLDBExpressionGenerator; + +public enum HSQLDBOracleFactory implements OracleFactory { + WHERE { + @Override + public TestOracle create(HSQLDBProvider.HSQLDBGlobalState globalState) + throws SQLException { + HSQLDBExpressionGenerator gen = new HSQLDBExpressionGenerator(globalState); + ExpectedErrors expectedErrors = ExpectedErrors.newErrors().with(HSQLDBErrors.getExpressionErrors()).build(); + + return new TLPWhereOracle<>(globalState, gen, expectedErrors); + } + }, + NOREC { + @Override + public TestOracle create(HSQLDBProvider.HSQLDBGlobalState globalState) + throws Exception { + HSQLDBExpressionGenerator gen = new HSQLDBExpressionGenerator(globalState); + ExpectedErrors errors = ExpectedErrors.newErrors().with(HSQLDBErrors.getExpressionErrors()).build(); + return new NoRECOracle<>(globalState, gen, errors); + } + } +} diff --git a/src/sqlancer/hsqldb/HSQLDBProvider.java b/src/sqlancer/hsqldb/HSQLDBProvider.java new file mode 100644 index 000000000..0490f90c7 --- /dev/null +++ b/src/sqlancer/hsqldb/HSQLDBProvider.java @@ -0,0 +1,110 @@ +package sqlancer.hsqldb; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.sql.Statement; + +import com.google.auto.service.AutoService; + +import sqlancer.AbstractAction; +import sqlancer.DatabaseProvider; +import sqlancer.IgnoreMeException; +import sqlancer.MainOptions; +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.SQLGlobalState; +import sqlancer.SQLProviderAdapter; +import sqlancer.StatementExecutor; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.query.SQLQueryProvider; +import sqlancer.hsqldb.gen.HSQLDBInsertGenerator; +import sqlancer.hsqldb.gen.HSQLDBTableGenerator; +import sqlancer.hsqldb.gen.HSQLDBUpdateGenerator; + +@AutoService(DatabaseProvider.class) +public class HSQLDBProvider extends SQLProviderAdapter { + + private static final String HSQLDB = "hsqldb"; + + public HSQLDBProvider() { + super(HSQLDBGlobalState.class, HSQLDBOptions.class); + } + + public enum Action implements AbstractAction { + INSERT(HSQLDBInsertGenerator::getQuery), UPDATE(HSQLDBUpdateGenerator::getQuery); + + private final SQLQueryProvider sqlQueryProvider; + + Action(SQLQueryProvider sqlQueryProvider) { + this.sqlQueryProvider = sqlQueryProvider; + } + + @Override + public SQLQueryAdapter getQuery(HSQLDBProvider.HSQLDBGlobalState state) throws Exception { + return sqlQueryProvider.getQuery(state); + } + } + + @Override + public SQLConnection createDatabase(HSQLDBGlobalState globalState) throws Exception { + String databaseName = globalState.getDatabaseName(); + String url = "jdbc:hsqldb:file:" + databaseName; + MainOptions options = globalState.getOptions(); + Connection connection = DriverManager.getConnection(url, options.getUserName(), options.getPassword()); + // When a server instance is started, or when a connection is made to an in-process database, + // a new, empty database is created if no database exists at the given path. + try (Statement s = connection.createStatement()) { + s.execute("DROP SCHEMA PUBLIC CASCADE"); + s.execute("SET DATABASE SQL DOUBLE NAN FALSE"); + } + return new SQLConnection(connection); + } + + @Override + public String getDBMSName() { + return HSQLDB; + } + + @Override + public void generateDatabase(HSQLDBGlobalState globalState) throws Exception { + for (int i = 0; i < Randomly.fromOptions(1, 2); i++) { + boolean success; + do { + SQLQueryAdapter qt = new HSQLDBTableGenerator().getQuery(globalState, null); + success = globalState.executeStatement(qt); + } while (!success); + } + if (globalState.getSchema().getDatabaseTables().isEmpty()) { + throw new IgnoreMeException(); + } + StatementExecutor se = new StatementExecutor<>(globalState, + HSQLDBProvider.Action.values(), HSQLDBProvider::mapActions, (q) -> { + if (globalState.getSchema().getDatabaseTables().isEmpty()) { + throw new IgnoreMeException(); + } + }); + se.executeStatements(); + } + + private static int mapActions(HSQLDBProvider.HSQLDBGlobalState globalState, HSQLDBProvider.Action a) { + Randomly r = globalState.getRandomly(); + switch (a) { + case INSERT: + return r.getInteger(0, globalState.getOptions().getMaxNumberInserts()); + case UPDATE: + return r.getInteger(0, 10); + default: + throw new AssertionError(a); + } + } + + public static class HSQLDBGlobalState extends SQLGlobalState { + + @Override + protected HSQLDBSchema readSchema() throws SQLException { + return HSQLDBSchema.fromConnection(getConnection(), getDatabaseName()); + } + + } +} diff --git a/src/sqlancer/hsqldb/HSQLDBSchema.java b/src/sqlancer/hsqldb/HSQLDBSchema.java new file mode 100644 index 000000000..2d41df83f --- /dev/null +++ b/src/sqlancer/hsqldb/HSQLDBSchema.java @@ -0,0 +1,159 @@ +package sqlancer.hsqldb; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.common.DBMSCommon; +import sqlancer.common.schema.AbstractRelationalTable; +import sqlancer.common.schema.AbstractSchema; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.common.schema.TableIndex; + +public class HSQLDBSchema extends AbstractSchema { + + public HSQLDBSchema(List databaseTables) { + super(databaseTables); + } + + public static HSQLDBSchema fromConnection(SQLConnection connection, String databaseName) throws SQLException { + List databaseTables = new ArrayList<>(); + List tableNames = getTableNames(connection); + for (String tableName : tableNames) { + if (DBMSCommon.matchesIndexName(tableName)) { + continue; // TODO: unexpected? + } + List databaseColumns = getTableColumns(connection, tableName); + boolean isView = matchesViewName(tableName); + HSQLDBSchema.HSQLDBTable t = new HSQLDBSchema.HSQLDBTable(tableName, databaseColumns, isView); + for (HSQLDBSchema.HSQLDBColumn c : databaseColumns) { + c.setTable(t); + } + databaseTables.add(t); + + } + return new HSQLDBSchema(databaseTables); + } + + private static List getTableNames(SQLConnection con) throws SQLException { + List tableNames = new ArrayList<>(); + try (Statement s = con.createStatement()) { + try (ResultSet rs = s + .executeQuery("SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'PUBLIC';")) { + while (rs.next()) { + tableNames.add(rs.getString("TABLE_NAME")); + } + } + } + return tableNames; + } + + private static List getTableColumns(SQLConnection con, String tableName) throws SQLException { + List tableNames = new ArrayList<>(); + try (Statement s = con.createStatement()) { + String sql = "SELECT COLUMN_NAME, DATA_TYPE, TYPE_NAME, COLUMN_SIZE FROM INFORMATION_SCHEMA.SYSTEM_COLUMNS WHERE TABLE_NAME = '%s';"; + try (ResultSet rs = s.executeQuery(String.format(sql, tableName))) { + while (rs.next()) { + HSQLDBDataType dataType = HSQLDBDataType.from(rs.getString("TYPE_NAME")); + HSQLDBCompositeDataType compositeDataType = new HSQLDBCompositeDataType(dataType, + rs.getInt("COLUMN_SIZE")); + HSQLDBColumn column = new HSQLDBColumn(rs.getString("COLUMN_NAME"), null, compositeDataType); + tableNames.add(column); + } + } + } + return tableNames; + } + + public static class HSQLDBTable + extends AbstractRelationalTable { + + public HSQLDBTable(String tableName, List columns, boolean isView) { + super(tableName, columns, Collections.emptyList(), isView); + } + + } + + public static class HSQLDBColumn + extends AbstractTableColumn { + + public HSQLDBColumn(String name, HSQLDBTable table, HSQLDBCompositeDataType type) { + super(name, table, type); + } + } + + public enum HSQLDBDataType { + + INTEGER, DOUBLE, BOOLEAN, CHAR, VARCHAR, BINARY, TIME, DATE, TIMESTAMP, NULL; + + public static HSQLDBSchema.HSQLDBDataType getRandomWithoutNull() { + HSQLDBSchema.HSQLDBDataType dt; + do { + dt = Randomly.fromOptions(values()); + } while (dt == HSQLDBSchema.HSQLDBDataType.NULL); + return dt; + } + + public static HSQLDBDataType from(String typeName) { + for (HSQLDBDataType value : HSQLDBDataType.values()) { + if (value.name().equals(typeName)) { + return value; + } + } + return NULL; + } + } + + public static class HSQLDBCompositeDataType { + private final int size; + private final HSQLDBDataType type; + + public HSQLDBCompositeDataType(HSQLDBDataType type, int size) { + this.type = type; + this.size = size; + } + + public static HSQLDBCompositeDataType getRandomWithoutNull() { + HSQLDBSchema.HSQLDBDataType type = HSQLDBSchema.HSQLDBDataType.getRandomWithoutNull(); + return getRandomWithType(type); + } + + public static HSQLDBCompositeDataType getRandomWithType(HSQLDBSchema.HSQLDBDataType type) { + int size; + switch (type) { + case VARCHAR: + case CHAR: + case TIME: + case BINARY: + case TIMESTAMP: + size = Randomly.fromOptions(4, 6, 8); + break; + case BOOLEAN: + case INTEGER: + case DOUBLE: + // case UUID: + // case OTHER: + case DATE: + size = 0; + break; + default: + throw new AssertionError(type); + } + + return new HSQLDBSchema.HSQLDBCompositeDataType(type, size); + } + + public HSQLDBDataType getType() { + return type; + } + + public int getSize() { + return size; + } + } +} diff --git a/src/sqlancer/hsqldb/HSQLDBToStringVisitor.java b/src/sqlancer/hsqldb/HSQLDBToStringVisitor.java new file mode 100644 index 000000000..99fc50687 --- /dev/null +++ b/src/sqlancer/hsqldb/HSQLDBToStringVisitor.java @@ -0,0 +1,89 @@ +package sqlancer.hsqldb; + +import sqlancer.common.ast.newast.NewToStringVisitor; +import sqlancer.hsqldb.ast.HSQLDBConstant; +import sqlancer.hsqldb.ast.HSQLDBExpression; +import sqlancer.hsqldb.ast.HSQLDBJoin; +import sqlancer.hsqldb.ast.HSQLDBSelect; + +public class HSQLDBToStringVisitor extends NewToStringVisitor { + + @Override + public void visitSpecific(HSQLDBExpression expr) { + if (expr instanceof HSQLDBConstant) { + visit((HSQLDBConstant) expr); + } else if (expr instanceof HSQLDBSelect) { + visit((HSQLDBSelect) expr); + } else if (expr instanceof HSQLDBJoin) { + visit((HSQLDBJoin) expr); + } else { + throw new AssertionError(expr.getClass()); + } + } + + public static String asString(HSQLDBExpression expr) { + HSQLDBToStringVisitor visitor = new HSQLDBToStringVisitor(); + visitor.visit(expr); + return visitor.get(); + } + + private void visit(HSQLDBJoin join) { + visit((HSQLDBExpression) join.getLeftTable()); + sb.append(" "); + sb.append(join.getJoinType()); + sb.append(" "); + if (join.getOuterType() != null) { + sb.append(join.getOuterType()); + } + sb.append(" JOIN "); + visit((HSQLDBExpression) join.getRightTable()); + if (join.getOnCondition() != null) { + sb.append(" ON "); + visit(join.getOnCondition()); + } + } + + private void visit(HSQLDBConstant constant) { + sb.append(constant.toString()); + } + + private void visit(HSQLDBSelect select) { + sb.append("SELECT "); + if (select.isDistinct()) { + sb.append("DISTINCT "); + } + visit(select.getFetchColumns()); + sb.append(" FROM "); + visit(select.getFromList()); + if (!select.getFromList().isEmpty() && !select.getJoinList().isEmpty()) { + sb.append(", "); + } + if (!select.getJoinList().isEmpty()) { + visit(select.getJoinList()); + } + if (select.getWhereClause() != null) { + sb.append(" WHERE "); + visit(select.getWhereClause()); + } + if (!select.getGroupByExpressions().isEmpty()) { + sb.append(" GROUP BY "); + visit(select.getGroupByExpressions()); + } + if (select.getHavingClause() != null) { + sb.append(" HAVING "); + visit(select.getHavingClause()); + } + if (!select.getOrderByClauses().isEmpty()) { + sb.append(" ORDER BY "); + visit(select.getOrderByClauses()); + } + if (select.getLimitClause() != null) { + sb.append(" LIMIT "); + visit(select.getLimitClause()); + } + if (select.getOffsetClause() != null) { + sb.append(" OFFSET "); + visit(select.getOffsetClause()); + } + } +} diff --git a/src/sqlancer/hsqldb/ast/HSQLDBBinaryOperation.java b/src/sqlancer/hsqldb/ast/HSQLDBBinaryOperation.java new file mode 100644 index 000000000..b821a90e2 --- /dev/null +++ b/src/sqlancer/hsqldb/ast/HSQLDBBinaryOperation.java @@ -0,0 +1,10 @@ +package sqlancer.hsqldb.ast; + +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.common.ast.newast.NewBinaryOperatorNode; + +public class HSQLDBBinaryOperation extends NewBinaryOperatorNode implements HSQLDBExpression { + public HSQLDBBinaryOperation(HSQLDBExpression left, HSQLDBExpression right, Operator op) { + super(left, right, op); + } +} diff --git a/src/sqlancer/hsqldb/ast/HSQLDBColumnReference.java b/src/sqlancer/hsqldb/ast/HSQLDBColumnReference.java new file mode 100644 index 000000000..cb579abac --- /dev/null +++ b/src/sqlancer/hsqldb/ast/HSQLDBColumnReference.java @@ -0,0 +1,12 @@ +package sqlancer.hsqldb.ast; + +import sqlancer.common.ast.newast.ColumnReferenceNode; +import sqlancer.hsqldb.HSQLDBSchema; + +public class HSQLDBColumnReference extends ColumnReferenceNode + implements HSQLDBExpression { + + public HSQLDBColumnReference(HSQLDBSchema.HSQLDBColumn column) { + super(column); + } +} diff --git a/src/sqlancer/hsqldb/ast/HSQLDBConstant.java b/src/sqlancer/hsqldb/ast/HSQLDBConstant.java new file mode 100644 index 000000000..67997d6bb --- /dev/null +++ b/src/sqlancer/hsqldb/ast/HSQLDBConstant.java @@ -0,0 +1,227 @@ +package sqlancer.hsqldb.ast; + +import java.sql.Timestamp; +import java.text.SimpleDateFormat; + +public class HSQLDBConstant implements HSQLDBExpression { + + private HSQLDBConstant() { + } + + public static class HSQLDBNullConstant extends HSQLDBConstant { + + @Override + public String toString() { + return "Null"; + } + + } + + public static class HSQLDBIntConstant extends HSQLDBConstant { + + private final int value; + + public HSQLDBIntConstant(long value) { + this.value = (int) value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + + public long getValue() { + return value; + } + + } + + public static class HSQLDBDoubleConstant extends HSQLDBConstant { + + private final double value; + + public HSQLDBDoubleConstant(double value) { + this.value = value; + } + + public double getValue() { + return value; + } + + @Override + public String toString() { + if (value == Double.POSITIVE_INFINITY) { + return "1.0e1/0.0e1"; + } else if (value == Double.NEGATIVE_INFINITY) { + return "-1.0e1/0.0e1"; + } + return String.valueOf(value); + } + + } + + public static class HSQLDBTextConstant extends HSQLDBConstant { + + private final String value; + + public HSQLDBTextConstant(String value) { + this.value = value; + } + + public HSQLDBTextConstant(String value, int size) { + this.value = value.substring(0, Math.min(value.length(), size)); + } + + public String getValue() { + return value; + } + + @Override + public String toString() { + return "'" + value.replace("'", "''") + "'"; + } + + } + + public static class HSQLDBBinaryConstant extends HSQLDBConstant { + + private final String value; + + public HSQLDBBinaryConstant(long value, int size) { + StringBuilder hex = new StringBuilder(Long.toHexString(value)); + if (hex.length() < 2) { + hex.append('0'); + } + this.value = hex.substring(0, Math.min(hex.length(), size)); + } + + public String getValue() { + return value; + } + + @Override + public String toString() { + return "X'" + value + "'"; + } + + } + + public static class HSQLDBDateConstant extends HSQLDBConstant { + + public String textRepr; + + public HSQLDBDateConstant(long val) { + Timestamp timestamp = new Timestamp(val); + SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd"); + textRepr = dateFormat.format(timestamp); + } + + public String getValue() { + return textRepr; + } + + @Override + public String toString() { + return String.format("DATE '%s'", textRepr); + } + + } + + public static class HSQLDBTimestampConstant extends HSQLDBConstant { + + public String textRepr; + + public HSQLDBTimestampConstant(long val) { + Timestamp timestamp = new Timestamp(val); + SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); + textRepr = dateFormat.format(timestamp); + } + + public String getValue() { + return textRepr; + } + + @Override + public String toString() { + return String.format("TIMESTAMP '%s'", textRepr); + } + + } + + public static class HSQLDBTimeConstant extends HSQLDBConstant { + + public String textRepr; + + public HSQLDBTimeConstant(long val) { + Timestamp timestamp = new Timestamp(val); + SimpleDateFormat dateFormat = new SimpleDateFormat("HH:mm:ss"); + textRepr = dateFormat.format(timestamp); + } + + public String getValue() { + return textRepr; + } + + @Override + public String toString() { + return String.format("TIME '%s'", textRepr); + } + + } + + public static class HSQLDBBooleanConstant extends HSQLDBConstant { + + private final boolean value; + + public HSQLDBBooleanConstant(boolean value) { + this.value = value; + } + + public boolean getValue() { + return value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + + } + + public static HSQLDBExpression createStringConstant(String text, int size) { + return new HSQLDBTextConstant(text, size); + } + + public static HSQLDBExpression createFloatConstant(double val) { + return new HSQLDBDoubleConstant(val); + } + + public static HSQLDBExpression createIntConstant(long val) { + return new HSQLDBIntConstant(val); + } + + public static HSQLDBExpression createNullConstant() { + return new HSQLDBNullConstant(); + } + + public static HSQLDBExpression createBooleanConstant(boolean val) { + return new HSQLDBBooleanConstant(val); + } + + public static HSQLDBExpression createDateConstant(long integer) { + return new HSQLDBDateConstant(integer); + } + + public static HSQLDBExpression createTimeConstant(long integer, int size) { + return new HSQLDBTimeConstant(integer); + } + + public static HSQLDBExpression createTimestampConstant(long integer, int size) { + return new HSQLDBTimestampConstant(integer); + } + + public static HSQLDBExpression createBinaryConstant(long nonCachedInteger, int size) { + return new HSQLDBBinaryConstant(nonCachedInteger, size); + } + +} diff --git a/src/sqlancer/hsqldb/ast/HSQLDBExpression.java b/src/sqlancer/hsqldb/ast/HSQLDBExpression.java new file mode 100644 index 000000000..d066b4359 --- /dev/null +++ b/src/sqlancer/hsqldb/ast/HSQLDBExpression.java @@ -0,0 +1,7 @@ +package sqlancer.hsqldb.ast; + +import sqlancer.common.ast.newast.Expression; +import sqlancer.hsqldb.HSQLDBSchema.HSQLDBColumn; + +public interface HSQLDBExpression extends Expression { +} diff --git a/src/sqlancer/hsqldb/ast/HSQLDBJoin.java b/src/sqlancer/hsqldb/ast/HSQLDBJoin.java new file mode 100644 index 000000000..a895e1b20 --- /dev/null +++ b/src/sqlancer/hsqldb/ast/HSQLDBJoin.java @@ -0,0 +1,91 @@ +package sqlancer.hsqldb.ast; + +import sqlancer.Randomly; +import sqlancer.common.ast.newast.Join; +import sqlancer.hsqldb.HSQLDBSchema.HSQLDBColumn; +import sqlancer.hsqldb.HSQLDBSchema.HSQLDBTable; + +public class HSQLDBJoin implements HSQLDBExpression, Join { + + private final HSQLDBTableReference leftTable; + private final HSQLDBTableReference rightTable; + private final JoinType joinType; + private HSQLDBExpression onCondition; + private OuterType outerType; + + public enum JoinType { + INNER, NATURAL, LEFT, RIGHT; + + public static JoinType getRandom() { + return Randomly.fromOptions(values()); + } + } + + public enum OuterType { + FULL, LEFT, RIGHT; + + public static OuterType getRandom() { + return Randomly.fromOptions(values()); + } + } + + public HSQLDBJoin(HSQLDBTableReference leftTable, HSQLDBTableReference rightTable, JoinType joinType, + HSQLDBExpression whereCondition) { + this.leftTable = leftTable; + this.rightTable = rightTable; + this.joinType = joinType; + this.onCondition = whereCondition; + } + + public HSQLDBTableReference getLeftTable() { + return leftTable; + } + + public HSQLDBTableReference getRightTable() { + return rightTable; + } + + public JoinType getJoinType() { + return joinType; + } + + public HSQLDBExpression getOnCondition() { + return onCondition; + } + + private void setOuterType(OuterType outerType) { + this.outerType = outerType; + } + + public OuterType getOuterType() { + return outerType; + } + + public static HSQLDBJoin createRightOuterJoin(HSQLDBTableReference left, HSQLDBTableReference right, + HSQLDBExpression predicate) { + return new HSQLDBJoin(left, right, JoinType.RIGHT, predicate); + } + + public static HSQLDBJoin createLeftOuterJoin(HSQLDBTableReference left, HSQLDBTableReference right, + HSQLDBExpression predicate) { + return new HSQLDBJoin(left, right, JoinType.LEFT, predicate); + } + + public static HSQLDBJoin createInnerJoin(HSQLDBTableReference left, HSQLDBTableReference right, + HSQLDBExpression predicate) { + return new HSQLDBJoin(left, right, JoinType.INNER, predicate); + } + + public static HSQLDBJoin createNaturalJoin(HSQLDBTableReference left, HSQLDBTableReference right, + OuterType naturalJoinType) { + HSQLDBJoin join = new HSQLDBJoin(left, right, JoinType.NATURAL, null); + join.setOuterType(naturalJoinType); + return join; + } + + @Override + public void setOnClause(HSQLDBExpression onClause) { + onCondition = onClause; + } + +} diff --git a/src/sqlancer/hsqldb/ast/HSQLDBSelect.java b/src/sqlancer/hsqldb/ast/HSQLDBSelect.java new file mode 100644 index 000000000..d58275d1e --- /dev/null +++ b/src/sqlancer/hsqldb/ast/HSQLDBSelect.java @@ -0,0 +1,41 @@ +package sqlancer.hsqldb.ast; + +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.common.ast.SelectBase; +import sqlancer.common.ast.newast.Select; +import sqlancer.hsqldb.HSQLDBSchema.HSQLDBColumn; +import sqlancer.hsqldb.HSQLDBSchema.HSQLDBTable; +import sqlancer.hsqldb.HSQLDBToStringVisitor; + +public class HSQLDBSelect extends SelectBase + implements HSQLDBExpression, Select { + + private boolean isDistinct; + + public void setDistinct(boolean isDistinct) { + this.isDistinct = isDistinct; + } + + public boolean isDistinct() { + return isDistinct; + } + + @Override + public void setJoinClauses(List joinStatements) { + List expressions = joinStatements.stream().map(e -> (HSQLDBExpression) e) + .collect(Collectors.toList()); + setJoinList(expressions); + } + + @Override + public List getJoinClauses() { + return getJoinList().stream().map(e -> (HSQLDBJoin) e).collect(Collectors.toList()); + } + + @Override + public String asString() { + return HSQLDBToStringVisitor.asString(this); + } +} diff --git a/src/sqlancer/hsqldb/ast/HSQLDBTableReference.java b/src/sqlancer/hsqldb/ast/HSQLDBTableReference.java new file mode 100644 index 000000000..010a921a7 --- /dev/null +++ b/src/sqlancer/hsqldb/ast/HSQLDBTableReference.java @@ -0,0 +1,11 @@ +package sqlancer.hsqldb.ast; + +import sqlancer.common.ast.newast.TableReferenceNode; +import sqlancer.hsqldb.HSQLDBSchema; + +public class HSQLDBTableReference extends TableReferenceNode + implements HSQLDBExpression { + public HSQLDBTableReference(HSQLDBSchema.HSQLDBTable table) { + super(table); + } +} diff --git a/src/sqlancer/hsqldb/ast/HSQLDBUnaryPostfixOperation.java b/src/sqlancer/hsqldb/ast/HSQLDBUnaryPostfixOperation.java new file mode 100644 index 000000000..5ebca87fa --- /dev/null +++ b/src/sqlancer/hsqldb/ast/HSQLDBUnaryPostfixOperation.java @@ -0,0 +1,57 @@ +package sqlancer.hsqldb.ast; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.newast.NewUnaryPostfixOperatorNode; +import sqlancer.hsqldb.HSQLDBSchema; + +public class HSQLDBUnaryPostfixOperation extends NewUnaryPostfixOperatorNode + implements HSQLDBExpression { + + public HSQLDBUnaryPostfixOperation(HSQLDBExpression expr, HSQLDBUnaryPostfixOperator op) { + super(expr, op); + } + + public enum HSQLDBUnaryPostfixOperator implements BinaryOperatorNode.Operator { + IS_NULL("IS NULL") { + @Override + public HSQLDBSchema.HSQLDBDataType[] getInputDataTypes() { + return HSQLDBSchema.HSQLDBDataType.values(); + } + }, + IS_NOT_NULL("IS NOT NULL") { + @Override + public HSQLDBSchema.HSQLDBDataType[] getInputDataTypes() { + return HSQLDBSchema.HSQLDBDataType.values(); + } + }; + + private final String textRepresentations; + + HSQLDBUnaryPostfixOperator(String text) { + this.textRepresentations = text; + } + + public static HSQLDBUnaryPostfixOperator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepresentations; + } + + public abstract HSQLDBSchema.HSQLDBDataType[] getInputDataTypes(); + + } + + public HSQLDBExpression getExpression() { + return getExpr(); + } + + @Override + public String getOperatorRepresentation() { + return this.op.getTextRepresentation(); + } + +} diff --git a/src/sqlancer/hsqldb/ast/HSQLDBUnaryPrefixOperation.java b/src/sqlancer/hsqldb/ast/HSQLDBUnaryPrefixOperation.java new file mode 100644 index 000000000..44b2beece --- /dev/null +++ b/src/sqlancer/hsqldb/ast/HSQLDBUnaryPrefixOperation.java @@ -0,0 +1,54 @@ +package sqlancer.hsqldb.ast; + +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode; +import sqlancer.hsqldb.HSQLDBSchema; + +public class HSQLDBUnaryPrefixOperation extends NewUnaryPrefixOperatorNode + implements HSQLDBExpression { + + public HSQLDBUnaryPrefixOperation(HSQLDBUnaryPrefixOperator operation, HSQLDBExpression expression) { + super(expression, operation); + } + + @Override + public String getOperatorRepresentation() { + return this.op.getTextRepresentation(); + } + + public enum HSQLDBUnaryPrefixOperator implements BinaryOperatorNode.Operator { + NOT("NOT", HSQLDBSchema.HSQLDBDataType.BOOLEAN, HSQLDBSchema.HSQLDBDataType.INTEGER) { + @Override + public HSQLDBSchema.HSQLDBDataType getExpressionType() { + return HSQLDBSchema.HSQLDBDataType.BOOLEAN; + } + }, + + UNARY_PLUS("+", HSQLDBSchema.HSQLDBDataType.INTEGER) { + @Override + public HSQLDBSchema.HSQLDBDataType getExpressionType() { + return HSQLDBSchema.HSQLDBDataType.INTEGER; + } + }, + UNARY_MINUS("-", HSQLDBSchema.HSQLDBDataType.INTEGER) { + @Override + public HSQLDBSchema.HSQLDBDataType getExpressionType() { + return HSQLDBSchema.HSQLDBDataType.INTEGER; + } + }; + + private String textRepresentation; + + HSQLDBUnaryPrefixOperator(String textRepresentation, HSQLDBSchema.HSQLDBDataType... dataTypes) { + this.textRepresentation = textRepresentation; + } + + public abstract HSQLDBSchema.HSQLDBDataType getExpressionType(); + + @Override + public String getTextRepresentation() { + return this.textRepresentation; + } + } + +} diff --git a/src/sqlancer/hsqldb/gen/HSQLDBExpressionGenerator.java b/src/sqlancer/hsqldb/gen/HSQLDBExpressionGenerator.java new file mode 100644 index 000000000..be73d1cb1 --- /dev/null +++ b/src/sqlancer/hsqldb/gen/HSQLDBExpressionGenerator.java @@ -0,0 +1,302 @@ +package sqlancer.hsqldb.gen; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.gen.NoRECGenerator; +import sqlancer.common.gen.TLPWhereGenerator; +import sqlancer.common.gen.TypedExpressionGenerator; +import sqlancer.common.schema.AbstractTables; +import sqlancer.hsqldb.HSQLDBProvider; +import sqlancer.hsqldb.HSQLDBSchema; +import sqlancer.hsqldb.HSQLDBSchema.HSQLDBColumn; +import sqlancer.hsqldb.HSQLDBSchema.HSQLDBTable; +import sqlancer.hsqldb.ast.HSQLDBBinaryOperation; +import sqlancer.hsqldb.ast.HSQLDBColumnReference; +import sqlancer.hsqldb.ast.HSQLDBConstant; +import sqlancer.hsqldb.ast.HSQLDBExpression; +import sqlancer.hsqldb.ast.HSQLDBJoin; +import sqlancer.hsqldb.ast.HSQLDBSelect; +import sqlancer.hsqldb.ast.HSQLDBTableReference; +import sqlancer.hsqldb.ast.HSQLDBUnaryPostfixOperation; +import sqlancer.hsqldb.ast.HSQLDBUnaryPrefixOperation; + +public final class HSQLDBExpressionGenerator extends + TypedExpressionGenerator + implements NoRECGenerator, + TLPWhereGenerator { + + List tables; + + private enum Expression { + BINARY_LOGICAL, BINARY_COMPARISON, BINARY_ARITHMETIC; + } + + HSQLDBProvider.HSQLDBGlobalState hsqldbGlobalState; + + public HSQLDBExpressionGenerator(HSQLDBProvider.HSQLDBGlobalState globalState) { + this.hsqldbGlobalState = globalState; + } + + @Override + public HSQLDBExpression generatePredicate() { + return generateExpression( + HSQLDBSchema.HSQLDBCompositeDataType.getRandomWithType(HSQLDBSchema.HSQLDBDataType.BOOLEAN)); + } + + @Override + public HSQLDBExpression negatePredicate(HSQLDBExpression predicate) { + return new HSQLDBUnaryPrefixOperation(HSQLDBUnaryPrefixOperation.HSQLDBUnaryPrefixOperator.NOT, predicate); + } + + @Override + public HSQLDBExpression isNull(HSQLDBExpression expr) { + return new HSQLDBUnaryPostfixOperation(expr, HSQLDBUnaryPostfixOperation.HSQLDBUnaryPostfixOperator.IS_NULL); + } + + @Override + public HSQLDBExpression generateConstant(HSQLDBSchema.HSQLDBCompositeDataType type) { + switch (type.getType()) { + case NULL: + return HSQLDBConstant.createNullConstant(); + case CHAR: + return HSQLDBConstant.HSQLDBTextConstant + .createStringConstant(hsqldbGlobalState.getRandomly().getAlphabeticChar(), type.getSize()); + case VARCHAR: + return HSQLDBConstant.HSQLDBTextConstant.createStringConstant(hsqldbGlobalState.getRandomly().getString(), + type.getSize()); + case TIME: + return HSQLDBConstant.createTimeConstant( + hsqldbGlobalState.getRandomly().getLong(0, System.currentTimeMillis()), type.getSize()); + case TIMESTAMP: + return HSQLDBConstant.createTimestampConstant( + hsqldbGlobalState.getRandomly().getLong(0, System.currentTimeMillis()), type.getSize()); + + case INTEGER: + return HSQLDBConstant.HSQLDBIntConstant.createIntConstant(Randomly.getNonCachedInteger()); + case DOUBLE: + return HSQLDBConstant.HSQLDBDoubleConstant.createFloatConstant(hsqldbGlobalState.getRandomly().getDouble()); + case BOOLEAN: + return HSQLDBConstant.HSQLDBBooleanConstant.createBooleanConstant(Randomly.getBoolean()); + case DATE: + return HSQLDBConstant + .createDateConstant(hsqldbGlobalState.getRandomly().getLong(0, System.currentTimeMillis())); + case BINARY: + return HSQLDBConstant.createBinaryConstant(Randomly.getNonCachedInteger(), type.getSize()); + default: + throw new AssertionError("Unknown type: " + type); + } + } + + @Override + protected HSQLDBExpression generateExpression(HSQLDBSchema.HSQLDBCompositeDataType type, int depth) { + if (depth >= hsqldbGlobalState.getOptions().getMaxExpressionDepth() + || Randomly.getBooleanWithSmallProbability()) { + return generateLeafNode(type); + } + + List possibleOptions = new ArrayList<>( + Arrays.asList(HSQLDBExpressionGenerator.Expression.values())); + + HSQLDBExpressionGenerator.Expression expr = Randomly.fromList(possibleOptions); + BinaryOperatorNode.Operator op; + switch (expr) { + case BINARY_LOGICAL: + case BINARY_ARITHMETIC: + op = HSQLDBExpressionGenerator.HSQLDBBinaryLogicalOperator.getRandom(); + break; + case BINARY_COMPARISON: + op = HSQLDBDBBinaryComparisonOperator.getRandom(); + break; + default: + throw new AssertionError(); + } + + return new HSQLDBBinaryOperation(generateExpression(type, depth + 1), generateExpression(type, depth + 1), op); + + } + + @Override + protected HSQLDBExpression generateColumn(HSQLDBSchema.HSQLDBCompositeDataType type) { + HSQLDBSchema.HSQLDBColumn column = Randomly + .fromList(columns.stream().filter(c -> c.getType() == type).collect(Collectors.toList())); + return new HSQLDBColumnReference(column); + } + + @Override + protected HSQLDBSchema.HSQLDBCompositeDataType getRandomType() { + return HSQLDBSchema.HSQLDBCompositeDataType.getRandomWithoutNull(); + } + + @Override + protected boolean canGenerateColumnOfType(HSQLDBSchema.HSQLDBCompositeDataType type) { + return columns.stream().anyMatch(c -> c.getType() == type); + } + + public enum HSQLDBBinaryLogicalOperator implements BinaryOperatorNode.Operator { + + AND, OR; + + @Override + public String getTextRepresentation() { + return toString(); + } + + public static BinaryOperatorNode.Operator getRandom() { + return Randomly.fromOptions(values()); + } + + } + + public enum HSQLDBDBBinaryComparisonOperator implements BinaryOperatorNode.Operator { + EQUALS("="), GREATER(">"), GREATER_EQUALS(">="), SMALLER("<"), SMALLER_EQUALS("<="), NOT_EQUALS("!="); + + private String textRepr; + + HSQLDBDBBinaryComparisonOperator(String textRepr) { + this.textRepr = textRepr; + } + + public static BinaryOperatorNode.Operator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepr; + } + + } + + public enum HSQLDBDBBinaryArithmeticOperator implements BinaryOperatorNode.Operator { + CONCAT("||"), ADD("+"), SUB("-"), MULT("*"), DIV("/"), MOD("%"), AND("&"), OR("|"), LSHIFT("<<"), RSHIFT(">>"); + + private String textRepr; + + HSQLDBDBBinaryArithmeticOperator(String textRepr) { + this.textRepr = textRepr; + } + + public static BinaryOperatorNode.Operator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepr; + } + + } + + @Override + public List generateOrderBys() { + List expressions = new ArrayList<>(); + int nr = Randomly.smallNumber() + 1; + ArrayList hsqldbColumns = new ArrayList<>(columns); + for (int i = 0; i < nr && !hsqldbColumns.isEmpty(); i++) { + HSQLDBSchema.HSQLDBColumn randomColumn = Randomly.fromList(hsqldbColumns); + HSQLDBColumnReference columnReference = new HSQLDBColumnReference(randomColumn); + hsqldbColumns.remove(randomColumn); + expressions.add(columnReference); + } + return expressions; + } + + @Override + public HSQLDBExpressionGenerator setTablesAndColumns(AbstractTables tables) { + this.columns = tables.getColumns(); + this.tables = tables.getTables(); + + return this; + } + + @Override + public HSQLDBExpression generateBooleanExpression() { + return generatePredicate(); + } + + @Override + public HSQLDBSelect generateSelect() { + return new HSQLDBSelect(); + } + + @Override + public List getRandomJoinClauses() { + List joinExpressions = new ArrayList<>(); + while (tables.size() >= 2 && Randomly.getBooleanWithRatherLowProbability()) { + HSQLDBTable leftTable = tables.remove(0); + HSQLDBTable rightTable = tables.remove(0); + List columns = new ArrayList<>(leftTable.getColumns()); + columns.addAll(rightTable.getColumns()); + HSQLDBExpressionGenerator joinGen = new HSQLDBExpressionGenerator(hsqldbGlobalState).setColumns(columns); + HSQLDBTableReference leftTableRef = new HSQLDBTableReference(leftTable); + HSQLDBTableReference rightTableRef = new HSQLDBTableReference(rightTable); + switch (HSQLDBJoin.JoinType.getRandom()) { + case INNER: + joinExpressions.add(HSQLDBJoin.createInnerJoin(leftTableRef, rightTableRef, + joinGen.generateExpression(HSQLDBSchema.HSQLDBCompositeDataType.getRandomWithoutNull()))); + break; + case NATURAL: + joinExpressions.add( + HSQLDBJoin.createNaturalJoin(leftTableRef, rightTableRef, HSQLDBJoin.OuterType.getRandom())); + break; + case LEFT: + joinExpressions.add(HSQLDBJoin.createLeftOuterJoin(leftTableRef, rightTableRef, + joinGen.generateExpression(HSQLDBSchema.HSQLDBCompositeDataType.getRandomWithoutNull()))); + break; + case RIGHT: + joinExpressions.add(HSQLDBJoin.createRightOuterJoin(leftTableRef, rightTableRef, + joinGen.generateExpression(HSQLDBSchema.HSQLDBCompositeDataType.getRandomWithoutNull()))); + break; + default: + throw new AssertionError(); + } + } + return joinExpressions; + } + + @Override + public List getTableRefs() { + return tables.stream().map(t -> new HSQLDBTableReference(t)).collect(Collectors.toList()); + } + + @Override + public String generateOptimizedQueryString(HSQLDBSelect select, HSQLDBExpression whereCondition, + boolean shouldUseAggregate) { + if (shouldUseAggregate) { + HSQLDBColumn aggr = new HSQLDBColumn("COUNT(*)", null, null); + select.setFetchColumns(List.of(new HSQLDBColumnReference(aggr))); + } else { + List allColumns = columns.stream().map((c) -> new HSQLDBColumnReference(c)) + .collect(Collectors.toList()); + select.setFetchColumns(allColumns); + if (Randomly.getBooleanWithSmallProbability()) { + select.setOrderByClauses(generateOrderBys()); + } + } + select.setWhereClause(whereCondition); + + return select.asString(); + } + + @Override + public String generateUnoptimizedQueryString(HSQLDBSelect select, HSQLDBExpression whereCondition) { + HSQLDBColumn c = new HSQLDBColumn("COUNT(*) as count", null, null); + select.setFetchColumns(List.of(new HSQLDBColumnReference(c))); + select.setWhereClause(null); + return "SELECT SUM(count) FROM (" + select.asString() + ") as res"; + } + + @Override + public List generateFetchColumns(boolean shouldCreateDummy) { + if (shouldCreateDummy) { + return List.of(new HSQLDBColumnReference(new HSQLDBSchema.HSQLDBColumn("*", null, null))); + } + return Randomly + .nonEmptySubset(columns.stream().map(c -> new HSQLDBColumnReference(c)).collect(Collectors.toList())); + } +} diff --git a/src/sqlancer/hsqldb/gen/HSQLDBInsertGenerator.java b/src/sqlancer/hsqldb/gen/HSQLDBInsertGenerator.java new file mode 100644 index 000000000..1cc132190 --- /dev/null +++ b/src/sqlancer/hsqldb/gen/HSQLDBInsertGenerator.java @@ -0,0 +1,48 @@ +package sqlancer.hsqldb.gen; + +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.common.gen.AbstractInsertGenerator; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.hsqldb.HSQLDBProvider; +import sqlancer.hsqldb.HSQLDBSchema; +import sqlancer.hsqldb.HSQLDBToStringVisitor; +import sqlancer.hsqldb.ast.HSQLDBExpression; + +public class HSQLDBInsertGenerator extends AbstractInsertGenerator { + + private final HSQLDBProvider.HSQLDBGlobalState globalState; + private final ExpectedErrors errors = new ExpectedErrors(); + + public HSQLDBInsertGenerator(HSQLDBProvider.HSQLDBGlobalState globalState) { + this.globalState = globalState; + } + + public static SQLQueryAdapter getQuery(HSQLDBProvider.HSQLDBGlobalState globalState) { + return new HSQLDBInsertGenerator(globalState).generate(); + } + + private SQLQueryAdapter generate() { + sb.append("INSERT INTO "); + HSQLDBSchema.HSQLDBTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + List columns = table.getRandomNonEmptyColumnSubset(); + sb.append(table.getName()); + sb.append("("); + sb.append(columns.stream().map(c -> c.getName()).collect(Collectors.joining(", "))); + sb.append(")"); + sb.append(" VALUES "); + insertColumns(columns); + // HSQLDBErrors.addInsertErrors(errors); + return new SQLQueryAdapter(sb.toString(), errors); + } + + @Override + protected void insertValue(HSQLDBSchema.HSQLDBColumn column) { + HSQLDBExpression expression = new HSQLDBExpressionGenerator(globalState).generateConstant(column.getType()); + String s = HSQLDBToStringVisitor.asString(expression); + sb.append(s); + } + +} diff --git a/src/sqlancer/hsqldb/gen/HSQLDBTableGenerator.java b/src/sqlancer/hsqldb/gen/HSQLDBTableGenerator.java new file mode 100644 index 000000000..48606e9bf --- /dev/null +++ b/src/sqlancer/hsqldb/gen/HSQLDBTableGenerator.java @@ -0,0 +1,58 @@ +package sqlancer.hsqldb.gen; + +import java.util.ArrayList; +import java.util.List; +import javax.annotation.Nullable; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.hsqldb.HSQLDBProvider; +import sqlancer.hsqldb.HSQLDBSchema; + +public class HSQLDBTableGenerator { + + public SQLQueryAdapter getQuery(HSQLDBProvider.HSQLDBGlobalState globalState, @Nullable String tableName) { + ExpectedErrors errors = new ExpectedErrors(); + StringBuilder sb = new StringBuilder(); + String name = tableName; + if (tableName == null) { + name = globalState.getSchema().getFreeTableName(); + } + sb.append("CREATE TABLE "); + if (Randomly.getBoolean()) { + sb.append("IF NOT EXISTS "); + } + sb.append(name); + sb.append("("); + List columns = getNewColumns(); + for (int i = 0; i < columns.size(); i++) { + if (i != 0) { + sb.append(", "); + } + sb.append(columns.get(i).getName()); + sb.append(" "); + sb.append(columns.get(i).getType().getType().name()); + if (columns.get(i).getType().getSize() > 0) { + // Cannot specify size for non composite data types + sb.append("("); + sb.append(columns.get(i).getType().getSize()); + sb.append(")"); + } + } + sb.append(")"); + sb.append(";"); + return new SQLQueryAdapter(sb.toString(), errors, true); + } + + private static List getNewColumns() { + List columns = new ArrayList<>(); + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + String columnName = String.format("c%d", i); + HSQLDBSchema.HSQLDBCompositeDataType columnType = HSQLDBSchema.HSQLDBCompositeDataType + .getRandomWithoutNull(); + columns.add(new HSQLDBSchema.HSQLDBColumn(columnName, null, columnType)); + } + return columns; + } +} diff --git a/src/sqlancer/hsqldb/gen/HSQLDBUpdateGenerator.java b/src/sqlancer/hsqldb/gen/HSQLDBUpdateGenerator.java new file mode 100644 index 000000000..e639e21b3 --- /dev/null +++ b/src/sqlancer/hsqldb/gen/HSQLDBUpdateGenerator.java @@ -0,0 +1,55 @@ +package sqlancer.hsqldb.gen; + +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.gen.AbstractUpdateGenerator; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.hsqldb.HSQLDBErrors; +import sqlancer.hsqldb.HSQLDBProvider; +import sqlancer.hsqldb.HSQLDBSchema; +import sqlancer.hsqldb.HSQLDBSchema.HSQLDBColumn; +import sqlancer.hsqldb.HSQLDBSchema.HSQLDBCompositeDataType; +import sqlancer.hsqldb.HSQLDBSchema.HSQLDBDataType; +import sqlancer.hsqldb.HSQLDBToStringVisitor; +import sqlancer.hsqldb.ast.HSQLDBExpression; + +public final class HSQLDBUpdateGenerator extends AbstractUpdateGenerator { + + private final HSQLDBProvider.HSQLDBGlobalState globalState; + private HSQLDBExpressionGenerator gen; + + private HSQLDBUpdateGenerator(HSQLDBProvider.HSQLDBGlobalState globalState) { + this.globalState = globalState; + } + + public static SQLQueryAdapter getQuery(HSQLDBProvider.HSQLDBGlobalState globalState) { + return new HSQLDBUpdateGenerator(globalState).generate(); + } + + private SQLQueryAdapter generate() { + HSQLDBSchema.HSQLDBTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + List columns = table.getRandomNonEmptyColumnSubset(); + gen = new HSQLDBExpressionGenerator(globalState).setColumns(table.getColumns()); + sb.append("UPDATE "); + sb.append(table.getName()); + sb.append(" SET "); + updateColumns(columns); + if (Randomly.getBooleanWithSmallProbability()) { + sb.append(" WHERE "); + sb.append(HSQLDBToStringVisitor.asString( + gen.generateExpression(HSQLDBCompositeDataType.getRandomWithType(HSQLDBDataType.BOOLEAN)))); + errors.add("data type of expression is not boolean"); + HSQLDBErrors.addExpressionErrors(errors); + } + return new SQLQueryAdapter(sb.toString(), errors); + } + + @Override + protected void updateValue(HSQLDBColumn column) { + HSQLDBExpression expr; + expr = gen.generateConstant(column.getType()); + sb.append(HSQLDBToStringVisitor.asString(expr)); + } + +} diff --git a/src/sqlancer/mariadb/MariaDBBugs.java b/src/sqlancer/mariadb/MariaDBBugs.java index 7a568b8fe..5a1bfa5cc 100644 --- a/src/sqlancer/mariadb/MariaDBBugs.java +++ b/src/sqlancer/mariadb/MariaDBBugs.java @@ -5,6 +5,30 @@ public final class MariaDBBugs { // https://jira.mariadb.org/browse/MDEV-21058 public static boolean bug21058 = true; + // https://jira.mariadb.org/browse/MDEV-32076 + public static boolean bug32076 = true; + + // https://jira.mariadb.org/browse/MDEV-32099 + public static boolean bug32099 = true; + + // https://jira.mariadb.org/browse/MDEV-32105 + public static boolean bug32105 = true; + + // https://jira.mariadb.org/browse/MDEV-32106 + public static boolean bug32106 = true; + + // https://jira.mariadb.org/browse/MDEV-32107 + public static boolean bug32107 = true; + + // https://jira.mariadb.org/browse/MDEV-32108 + public static boolean bug32108 = true; + + // https://jira.mariadb.org/browse/MDEV-32143 + public static boolean bug32143 = true; + + // https://jira.mariadb.org/browse/MDEV-33893 + public static boolean bug33893 = true; + private MariaDBBugs() { } diff --git a/src/sqlancer/mariadb/MariaDBErrors.java b/src/sqlancer/mariadb/MariaDBErrors.java index ab1bec4cd..b84deba02 100644 --- a/src/sqlancer/mariadb/MariaDBErrors.java +++ b/src/sqlancer/mariadb/MariaDBErrors.java @@ -1,5 +1,8 @@ package sqlancer.mariadb; +import java.util.ArrayList; +import java.util.List; + import sqlancer.common.query.ExpectedErrors; public final class MariaDBErrors { @@ -7,7 +10,46 @@ public final class MariaDBErrors { private MariaDBErrors() { } - public static void addInsertErrors(ExpectedErrors errors) { + public static List getCommonErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("is out of range"); + // regex + errors.add("unmatched parentheses"); + errors.add("nothing to repeat at offset"); + errors.add("missing )"); + errors.add("missing terminating ]"); + errors.add("range out of order in character class"); + errors.add("unrecognized character after "); + errors.add("Got error '(*VERB) not recognized or malformed"); + errors.add("must be followed by"); + errors.add("malformed number or name after"); + errors.add("digit expected after"); + errors.add("Regex error"); + errors.add("Lock wait timeout exceeded"); + + return errors; + } + + public static void addCommonErrors(ExpectedErrors errors) { + errors.add("is out of range"); + // regex + errors.add("unmatched parentheses"); + errors.add("nothing to repeat at offset"); + errors.add("missing )"); + errors.add("missing terminating ]"); + errors.add("range out of order in character class"); + errors.add("unrecognized character after "); + errors.add("Got error '(*VERB) not recognized or malformed"); + errors.add("must be followed by"); + errors.add("malformed number or name after"); + errors.add("digit expected after"); + errors.add("Regex error"); + errors.add("Lock wait timeout exceeded"); + } + + public static List getInsertErrors() { + ArrayList errors = new ArrayList<>(); errors.add("Out of range"); errors.add("Duplicate entry"); // violates UNIQUE constraint errors.add("cannot be null"); // violates NOT NULL constraint @@ -17,6 +59,11 @@ public static void addInsertErrors(ExpectedErrors errors) { errors.add("The value specified for generated column"); // trying to insert into a generated column errors.add("Incorrect double value"); errors.add("Incorrect string value"); + return errors; + } + + public static void addInsertErrors(ExpectedErrors errors) { + errors.addAll(getInsertErrors()); } } diff --git a/src/sqlancer/mariadb/MariaDBOptions.java b/src/sqlancer/mariadb/MariaDBOptions.java index 02e6eab64..b1e7f807e 100644 --- a/src/sqlancer/mariadb/MariaDBOptions.java +++ b/src/sqlancer/mariadb/MariaDBOptions.java @@ -1,6 +1,5 @@ package sqlancer.mariadb; -import java.sql.SQLException; import java.util.Arrays; import java.util.List; @@ -8,11 +7,6 @@ import com.beust.jcommander.Parameters; import sqlancer.DBMSSpecificOptions; -import sqlancer.OracleFactory; -import sqlancer.common.oracle.TestOracle; -import sqlancer.mariadb.MariaDBOptions.MariaDBOracleFactory; -import sqlancer.mariadb.MariaDBProvider.MariaDBGlobalState; -import sqlancer.mariadb.oracle.MariaDBNoRECOracle; @Parameters(separators = "=", commandDescription = "MariaDB (default port: " + MariaDBOptions.DEFAULT_PORT + ", default host: " + MariaDBOptions.DEFAULT_HOST + ")") @@ -23,18 +17,6 @@ public class MariaDBOptions implements DBMSSpecificOptions @Parameter(names = "--oracle") public List oracles = Arrays.asList(MariaDBOracleFactory.NOREC); - public enum MariaDBOracleFactory implements OracleFactory { - - NOREC { - - @Override - public TestOracle create(MariaDBGlobalState globalState) throws SQLException { - return new MariaDBNoRECOracle(globalState); - } - - } - } - @Override public List getTestOracleFactory() { return oracles; diff --git a/src/sqlancer/mariadb/MariaDBOracleFactory.java b/src/sqlancer/mariadb/MariaDBOracleFactory.java new file mode 100644 index 000000000..549da27c8 --- /dev/null +++ b/src/sqlancer/mariadb/MariaDBOracleFactory.java @@ -0,0 +1,36 @@ +package sqlancer.mariadb; + +import java.sql.SQLException; + +import sqlancer.OracleFactory; +import sqlancer.common.oracle.NoRECOracle; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.mariadb.gen.MariaDBExpressionGenerator; +import sqlancer.mariadb.oracle.MariaDBDQPOracle; + +public enum MariaDBOracleFactory implements OracleFactory { + + NOREC { + @Override + public TestOracle create(MariaDBProvider.MariaDBGlobalState globalState) + throws SQLException { + MariaDBExpressionGenerator gen = new MariaDBExpressionGenerator(globalState.getRandomly()); + ExpectedErrors errors = ExpectedErrors.newErrors().with(MariaDBErrors.getCommonErrors()) + .with("is out of range").with("unmatched parentheses").with("nothing to repeat at offset") + .with("missing )").with("missing terminating ]").with("range out of order in character class") + .with("unrecognized character after ").with("Got error '(*VERB) not recognized or malformed") + .with("must be followed by").with("malformed number or name after").with("digit expected after") + .with("Could not create a join buffer").build(); + return new NoRECOracle<>(globalState, gen, errors); + } + + }, + DQP { + @Override + public TestOracle create(MariaDBProvider.MariaDBGlobalState globalState) + throws SQLException { + return new MariaDBDQPOracle(globalState); + } + } +} diff --git a/src/sqlancer/mariadb/MariaDBProvider.java b/src/sqlancer/mariadb/MariaDBProvider.java index c84691203..a9737f549 100644 --- a/src/sqlancer/mariadb/MariaDBProvider.java +++ b/src/sqlancer/mariadb/MariaDBProvider.java @@ -19,6 +19,7 @@ import sqlancer.common.DBMSCommon; import sqlancer.common.query.SQLQueryAdapter; import sqlancer.mariadb.MariaDBProvider.MariaDBGlobalState; +import sqlancer.mariadb.gen.MariaDBDeleteGenerator; import sqlancer.mariadb.gen.MariaDBIndexGenerator; import sqlancer.mariadb.gen.MariaDBInsertGenerator; import sqlancer.mariadb.gen.MariaDBSetGenerator; @@ -47,13 +48,14 @@ enum Action { SET, // TRUNCATE, // UPDATE, // + DELETE, } @Override public void generateDatabase(MariaDBGlobalState globalState) throws Exception { MainOptions options = globalState.getOptions(); - while (globalState.getSchema().getDatabaseTables().size() < Randomly.smallNumber() + 1) { + while (globalState.getSchema().getDatabaseTables().size() < Randomly.getNotCachedInteger(1, 3)) { String tableName = DBMSCommon.createTableName(globalState.getSchema().getDatabaseTables().size()); SQLQueryAdapter createTable = MariaDBTableGenerator.generate(tableName, globalState.getRandomly(), globalState.getSchema()); @@ -77,6 +79,9 @@ public void generateDatabase(MariaDBGlobalState globalState) throws Exception { case CREATE_INDEX: nrPerformed = globalState.getRandomly().getInteger(0, 2); break; + case DELETE: + nrPerformed = globalState.getRandomly().getInteger(0, 2); + break; case SET: nrPerformed = 20; break; @@ -140,6 +145,9 @@ public void generateDatabase(MariaDBGlobalState globalState) throws Exception { case SET: query = MariaDBSetGenerator.set(globalState.getRandomly(), options); break; + case DELETE: + query = MariaDBDeleteGenerator.delete(globalState.getSchema(), globalState.getRandomly()); + break; default: throw new AssertionError(nextAction); } diff --git a/src/sqlancer/mariadb/MariaDBSchema.java b/src/sqlancer/mariadb/MariaDBSchema.java index ed432e8da..7f7656d76 100644 --- a/src/sqlancer/mariadb/MariaDBSchema.java +++ b/src/sqlancer/mariadb/MariaDBSchema.java @@ -51,10 +51,15 @@ public int getPrecision() { return precision; } + @Override public boolean isPrimaryKey() { return isPrimaryKey; } + public static MariaDBColumn createDummy(String name) { + return new MariaDBColumn(name, MariaDBDataType.INT, false, 1); + } + } public static class MariaDBTables { @@ -242,4 +247,8 @@ public MariaDBSchema(List databaseTables) { super(databaseTables); } + public MariaDBTables getRandomTableNonEmptyTables() { + return new MariaDBTables(Randomly.nonEmptySubset(getDatabaseTables())); + } + } diff --git a/src/sqlancer/mariadb/ast/MariaDBAggregate.java b/src/sqlancer/mariadb/ast/MariaDBAggregate.java index 3977a6a95..ab5a47781 100644 --- a/src/sqlancer/mariadb/ast/MariaDBAggregate.java +++ b/src/sqlancer/mariadb/ast/MariaDBAggregate.java @@ -1,6 +1,6 @@ package sqlancer.mariadb.ast; -public class MariaDBAggregate extends MariaDBExpression { +public class MariaDBAggregate implements MariaDBExpression { private final MariaDBExpression expr; private final MariaDBAggregateFunction aggr; diff --git a/src/sqlancer/mariadb/ast/MariaDBBinaryOperator.java b/src/sqlancer/mariadb/ast/MariaDBBinaryOperator.java index 14f3c308a..5f56d168f 100644 --- a/src/sqlancer/mariadb/ast/MariaDBBinaryOperator.java +++ b/src/sqlancer/mariadb/ast/MariaDBBinaryOperator.java @@ -2,7 +2,7 @@ import sqlancer.Randomly; -public class MariaDBBinaryOperator extends MariaDBExpression { +public class MariaDBBinaryOperator implements MariaDBExpression { private MariaDBExpression left; private MariaDBExpression right; diff --git a/src/sqlancer/mariadb/ast/MariaDBColumnName.java b/src/sqlancer/mariadb/ast/MariaDBColumnName.java index a4fd78c1b..65d4cea0c 100644 --- a/src/sqlancer/mariadb/ast/MariaDBColumnName.java +++ b/src/sqlancer/mariadb/ast/MariaDBColumnName.java @@ -2,7 +2,7 @@ import sqlancer.mariadb.MariaDBSchema.MariaDBColumn; -public class MariaDBColumnName extends MariaDBExpression { +public class MariaDBColumnName implements MariaDBExpression { private final MariaDBColumn column; diff --git a/src/sqlancer/mariadb/ast/MariaDBConstant.java b/src/sqlancer/mariadb/ast/MariaDBConstant.java index 6e3d77691..8670c9ed2 100644 --- a/src/sqlancer/mariadb/ast/MariaDBConstant.java +++ b/src/sqlancer/mariadb/ast/MariaDBConstant.java @@ -1,6 +1,9 @@ package sqlancer.mariadb.ast; -public class MariaDBConstant extends MariaDBExpression { +public class MariaDBConstant implements MariaDBExpression { + + private MariaDBConstant() { + } public static class MariaDBNullConstant extends MariaDBConstant { diff --git a/src/sqlancer/mariadb/ast/MariaDBExpression.java b/src/sqlancer/mariadb/ast/MariaDBExpression.java index beff58866..d57c14888 100644 --- a/src/sqlancer/mariadb/ast/MariaDBExpression.java +++ b/src/sqlancer/mariadb/ast/MariaDBExpression.java @@ -1,5 +1,8 @@ package sqlancer.mariadb.ast; -public class MariaDBExpression { +import sqlancer.common.ast.newast.Expression; +import sqlancer.mariadb.MariaDBSchema.MariaDBColumn; + +public interface MariaDBExpression extends Expression { } diff --git a/src/sqlancer/mariadb/ast/MariaDBFunction.java b/src/sqlancer/mariadb/ast/MariaDBFunction.java index fd033d310..481a26aaf 100644 --- a/src/sqlancer/mariadb/ast/MariaDBFunction.java +++ b/src/sqlancer/mariadb/ast/MariaDBFunction.java @@ -2,7 +2,7 @@ import java.util.List; -public class MariaDBFunction extends MariaDBExpression { +public class MariaDBFunction implements MariaDBExpression { private final MariaDBFunctionName func; private final List args; diff --git a/src/sqlancer/mariadb/ast/MariaDBInOperation.java b/src/sqlancer/mariadb/ast/MariaDBInOperation.java index 85981b390..15aad4daa 100644 --- a/src/sqlancer/mariadb/ast/MariaDBInOperation.java +++ b/src/sqlancer/mariadb/ast/MariaDBInOperation.java @@ -2,7 +2,7 @@ import java.util.List; -public class MariaDBInOperation extends MariaDBExpression { +public class MariaDBInOperation implements MariaDBExpression { private final MariaDBExpression expr; private final List list; diff --git a/src/sqlancer/mariadb/ast/MariaDBJoin.java b/src/sqlancer/mariadb/ast/MariaDBJoin.java new file mode 100644 index 000000000..f1d0892e1 --- /dev/null +++ b/src/sqlancer/mariadb/ast/MariaDBJoin.java @@ -0,0 +1,86 @@ +package sqlancer.mariadb.ast; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.ast.newast.Join; +import sqlancer.mariadb.MariaDBSchema.MariaDBColumn; +import sqlancer.mariadb.MariaDBSchema.MariaDBTable; +import sqlancer.mariadb.gen.MariaDBExpressionGenerator; + +public class MariaDBJoin implements MariaDBExpression, Join { + + public enum JoinType { + NATURAL, INNER, STRAIGHT, LEFT, RIGHT, CROSS; + } + + private final MariaDBTable table; + private MariaDBExpression onClause; + private JoinType type; + + public MariaDBJoin(MariaDBJoin other) { + this.table = other.table; + this.onClause = other.onClause; + this.type = other.type; + } + + public MariaDBJoin(MariaDBTable table, MariaDBExpression onClause, JoinType type) { + this.table = table; + this.onClause = onClause; + this.type = type; + } + + public MariaDBTable getTable() { + return table; + } + + public MariaDBExpression getOnClause() { + return onClause; + } + + public JoinType getType() { + return type; + } + + @Override + public void setOnClause(MariaDBExpression onClause) { + this.onClause = onClause; + } + + public void setType(JoinType type) { + this.type = type; + } + + public static List getRandomJoinClauses(List tables, Randomly r) { + List joinStatements = new ArrayList<>(); + List options = new ArrayList<>(Arrays.asList(JoinType.values())); + List columns = new ArrayList<>(); + if (tables.size() > 1) { + int nrJoinClauses = (int) Randomly.getNotCachedInteger(0, tables.size()); + // Natural join is incompatible with other joins + // because it needs unique column names + // while other joins will produce duplicate column names + if (nrJoinClauses > 1) { + options.remove(JoinType.NATURAL); + } + for (int i = 0; i < nrJoinClauses; i++) { + MariaDBTable table = Randomly.fromList(tables); + tables.remove(table); + columns.addAll(table.getColumns()); + MariaDBExpressionGenerator joinGen = new MariaDBExpressionGenerator(r).setColumns(columns); + MariaDBExpression joinClause = joinGen.getRandomExpression(); + JoinType selectedOption = Randomly.fromList(options); + if (selectedOption == JoinType.NATURAL) { + // NATURAL joins do not have an ON clause + joinClause = null; + } + MariaDBJoin j = new MariaDBJoin(table, joinClause, selectedOption); + joinStatements.add(j); + } + + } + return joinStatements; + } +} diff --git a/src/sqlancer/mariadb/ast/MariaDBPostfixUnaryOperation.java b/src/sqlancer/mariadb/ast/MariaDBPostfixUnaryOperation.java index cd655a5d1..9b56a44dd 100644 --- a/src/sqlancer/mariadb/ast/MariaDBPostfixUnaryOperation.java +++ b/src/sqlancer/mariadb/ast/MariaDBPostfixUnaryOperation.java @@ -2,7 +2,7 @@ import sqlancer.Randomly; -public class MariaDBPostfixUnaryOperation extends MariaDBExpression { +public class MariaDBPostfixUnaryOperation implements MariaDBExpression { private MariaDBPostfixUnaryOperator operator; private MariaDBExpression randomWhereCondition; diff --git a/src/sqlancer/mariadb/ast/MariaDBSelectStatement.java b/src/sqlancer/mariadb/ast/MariaDBSelectStatement.java index d1700f508..f09e7051f 100644 --- a/src/sqlancer/mariadb/ast/MariaDBSelectStatement.java +++ b/src/sqlancer/mariadb/ast/MariaDBSelectStatement.java @@ -3,37 +3,40 @@ import java.util.ArrayList; import java.util.List; +import sqlancer.common.ast.SelectBase; +import sqlancer.common.ast.newast.Select; +import sqlancer.mariadb.MariaDBSchema.MariaDBColumn; import sqlancer.mariadb.MariaDBSchema.MariaDBTable; -public class MariaDBSelectStatement extends MariaDBExpression { +public class MariaDBSelectStatement extends SelectBase + implements MariaDBExpression, Select { public enum MariaDBSelectType { - ALL + ALL, DISTINCT, DISTINCTROW; } private List groupBys = new ArrayList<>(); private List columns = new ArrayList<>(); - private List tables = new ArrayList<>(); + private List joinClauses = new ArrayList<>(); private MariaDBSelectType selectType = MariaDBSelectType.ALL; private MariaDBExpression whereCondition; + @Override public void setGroupByClause(List groupBys) { this.groupBys = groupBys; } + @Override public void setFetchColumns(List columns) { this.columns = columns; } - public void setFromTables(List tables) { - this.tables = tables; - } - public void setSelectType(MariaDBSelectType selectType) { this.selectType = selectType; } + @Override public void setWhereClause(MariaDBExpression whereCondition) { this.whereCondition = whereCondition; } @@ -50,12 +53,22 @@ public MariaDBSelectType getSelectType() { return selectType; } - public List getTables() { - return tables; - } - public MariaDBExpression getWhereCondition() { return whereCondition; } + @Override + public List getJoinClauses() { + return joinClauses; + } + + @Override + public void setJoinClauses(List joinClauses) { + this.joinClauses = joinClauses; + } + + @Override + public String asString() { + return MariaDBVisitor.asString(this); + } } diff --git a/src/sqlancer/mariadb/ast/MariaDBStringVisitor.java b/src/sqlancer/mariadb/ast/MariaDBStringVisitor.java index 53ed490d0..f7fed0c83 100644 --- a/src/sqlancer/mariadb/ast/MariaDBStringVisitor.java +++ b/src/sqlancer/mariadb/ast/MariaDBStringVisitor.java @@ -1,7 +1,6 @@ package sqlancer.mariadb.ast; import java.util.List; -import java.util.stream.Collectors; public class MariaDBStringVisitor extends MariaDBVisitor { @@ -27,7 +26,7 @@ public void visit(MariaDBPostfixUnaryOperation op) { @Override public void visit(MariaDBColumnName c) { - sb.append(c.getColumn().getName()); + sb.append(c.getColumn().getFullQualifiedName()); } @Override @@ -41,13 +40,22 @@ public void visit(MariaDBSelectStatement s) { visit(column); } sb.append(" FROM "); - sb.append(s.getTables().stream().map(t -> t.getName()).collect(Collectors.joining(", "))); + + for (int j = 0; j < s.getFromList().size(); j++) { + if (j != 0) { + sb.append(", "); + } + visit(s.getFromList().get(j)); + } + for (MariaDBExpression j : s.getJoinClauses()) { + visit(j); + } if (s.getWhereCondition() != null) { sb.append(" WHERE "); visit(s.getWhereCondition()); } - if (s.getGroupBys().size() != 0) { - sb.append(" GROUP BY"); + if (!s.getGroupBys().isEmpty()) { + sb.append(" GROUP BY "); for (i = 0; i < s.getGroupBys().size(); i++) { if (i != 0) { sb.append(", "); @@ -131,4 +139,41 @@ private void visitList(List list) { } } + @Override + public void visit(MariaDBJoin join) { + sb.append(" "); + switch (join.getType()) { + case NATURAL: + sb.append("NATURAL "); + break; + case INNER: + sb.append("INNER "); + break; + case STRAIGHT: + sb.append("STRAIGHT_"); + break; + case LEFT: + sb.append("LEFT "); + break; + case RIGHT: + sb.append("RIGHT "); + break; + case CROSS: + sb.append("CROSS "); + break; + default: + throw new AssertionError(join.getType()); + } + sb.append("JOIN "); + sb.append(join.getTable().getName()); + if (join.getOnClause() != null) { + sb.append(" ON "); + visit(join.getOnClause()); + } + } + + @Override + public void visit(MariaDBTableReference ref) { + sb.append(ref.getTable().getName()); + } } diff --git a/src/sqlancer/mariadb/ast/MariaDBTableReference.java b/src/sqlancer/mariadb/ast/MariaDBTableReference.java new file mode 100644 index 000000000..f045907ca --- /dev/null +++ b/src/sqlancer/mariadb/ast/MariaDBTableReference.java @@ -0,0 +1,16 @@ +package sqlancer.mariadb.ast; + +import sqlancer.mariadb.MariaDBSchema.MariaDBTable; + +public class MariaDBTableReference implements MariaDBExpression { + + private final MariaDBTable table; + + public MariaDBTableReference(MariaDBTable table) { + this.table = table; + } + + public MariaDBTable getTable() { + return table; + } +} diff --git a/src/sqlancer/mariadb/ast/MariaDBText.java b/src/sqlancer/mariadb/ast/MariaDBText.java index b96871063..c3d1c2d1a 100644 --- a/src/sqlancer/mariadb/ast/MariaDBText.java +++ b/src/sqlancer/mariadb/ast/MariaDBText.java @@ -1,6 +1,6 @@ package sqlancer.mariadb.ast; -public class MariaDBText extends MariaDBExpression { +public class MariaDBText implements MariaDBExpression { private final MariaDBExpression expr; private final String text; diff --git a/src/sqlancer/mariadb/ast/MariaDBUnaryPrefixOperation.java b/src/sqlancer/mariadb/ast/MariaDBUnaryPrefixOperation.java index af229ce6e..dda61dbee 100644 --- a/src/sqlancer/mariadb/ast/MariaDBUnaryPrefixOperation.java +++ b/src/sqlancer/mariadb/ast/MariaDBUnaryPrefixOperation.java @@ -2,7 +2,7 @@ import sqlancer.Randomly; -public class MariaDBUnaryPrefixOperation extends MariaDBExpression { +public class MariaDBUnaryPrefixOperation implements MariaDBExpression { private MariaDBExpression expr; private MariaDBUnaryPrefixOperator op; diff --git a/src/sqlancer/mariadb/ast/MariaDBVisitor.java b/src/sqlancer/mariadb/ast/MariaDBVisitor.java index 8ea3f9d37..8626dc967 100644 --- a/src/sqlancer/mariadb/ast/MariaDBVisitor.java +++ b/src/sqlancer/mariadb/ast/MariaDBVisitor.java @@ -22,6 +22,10 @@ public abstract class MariaDBVisitor { public abstract void visit(MariaDBInOperation op); + public abstract void visit(MariaDBJoin join); + + public abstract void visit(MariaDBTableReference join); + public void visit(MariaDBExpression expr) { if (expr instanceof MariaDBConstant) { visit((MariaDBConstant) expr); @@ -43,6 +47,10 @@ public void visit(MariaDBExpression expr) { visit((MariaDBFunction) expr); } else if (expr instanceof MariaDBInOperation) { visit((MariaDBInOperation) expr); + } else if (expr instanceof MariaDBJoin) { + visit((MariaDBJoin) expr); + } else if (expr instanceof MariaDBTableReference) { + visit((MariaDBTableReference) expr); } else { throw new AssertionError(expr.getClass()); } diff --git a/src/sqlancer/mariadb/gen/MariaDBDeleteGenerator.java b/src/sqlancer/mariadb/gen/MariaDBDeleteGenerator.java new file mode 100644 index 000000000..6d85eb891 --- /dev/null +++ b/src/sqlancer/mariadb/gen/MariaDBDeleteGenerator.java @@ -0,0 +1,93 @@ +package sqlancer.mariadb.gen; + +import java.util.Collections; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.schema.AbstractTables; +import sqlancer.mariadb.MariaDBSchema; +import sqlancer.mariadb.MariaDBSchema.MariaDBColumn; +import sqlancer.mariadb.MariaDBSchema.MariaDBTable; +import sqlancer.mariadb.ast.MariaDBVisitor; + +public final class MariaDBDeleteGenerator { + + private MariaDBDeleteGenerator() { + } + + public static SQLQueryAdapter delete(MariaDBSchema schema, Randomly r) { + MariaDBTable table = schema.getRandomTable(); + + MariaDBExpressionGenerator expressionGenerator = new MariaDBExpressionGenerator(r); + + AbstractTables tablesAndColumns = new AbstractTables<>( + Collections.singletonList(table)); + expressionGenerator.setTablesAndColumns(tablesAndColumns); + + ExpectedErrors errors = new ExpectedErrors(); + + errors.add("foreign key constraint fails"); + errors.add("cannot delete or update a parent row"); + errors.add("Data truncated"); + errors.add("Division by 0"); + errors.add("Incorrect value"); + + StringBuilder sb = new StringBuilder("DELETE"); + + if (Randomly.getBooleanWithRatherLowProbability()) { + sb.append(" LOW_PRIORITY"); + } + if (Randomly.getBooleanWithRatherLowProbability()) { + sb.append(" QUICK"); + } + if (Randomly.getBooleanWithRatherLowProbability()) { + sb.append(" IGNORE"); + } + + sb.append(" FROM "); + sb.append(table.getName()); + + if (Randomly.getBoolean()) { + sb.append(" WHERE "); + if (Randomly.getBooleanWithRatherLowProbability()) { + sb.append(MariaDBVisitor.asString(MariaDBExpressionGenerator.getRandomConstant(r))); + } else { + sb.append(MariaDBVisitor.asString(expressionGenerator.getRandomExpression())); + } + } + + // ORDER BY + LIMIT + if (Randomly.getBooleanWithRatherLowProbability() && !table.getColumns().isEmpty()) { + sb.append(" ORDER BY "); + sb.append(Randomly.fromList(table.getColumns()).getName()); + if (Randomly.getBoolean()) { + sb.append(Randomly.getBoolean() ? " ASC" : " DESC"); + } + } + + if (Randomly.getBooleanWithRatherLowProbability()) { + sb.append(" LIMIT "); + sb.append(Randomly.getNotCachedInteger(1, 10)); + } + + // RETURNING clause (MariaDB >= 10.5) + if (Randomly.getBooleanWithRatherLowProbability()) { + sb.append(" RETURNING "); + if (Randomly.getBooleanWithRatherLowProbability()) { + sb.append(MariaDBVisitor.asString(MariaDBExpressionGenerator.getRandomConstant(r))); + } else { + sb.append(MariaDBVisitor.asString(expressionGenerator.getRandomExpression())); + } + } + + String query = sb.toString(); + if (query.contains("RLIKE") || query.contains("REGEXP")) { + errors.add("Regex error"); + errors.add("quantifier does not follow a repeatable item"); + errors.add("Got error"); + } + + return new SQLQueryAdapter(query, errors); + } +} diff --git a/src/sqlancer/mariadb/gen/MariaDBExpressionGenerator.java b/src/sqlancer/mariadb/gen/MariaDBExpressionGenerator.java index 062bf509c..3b7386121 100644 --- a/src/sqlancer/mariadb/gen/MariaDBExpressionGenerator.java +++ b/src/sqlancer/mariadb/gen/MariaDBExpressionGenerator.java @@ -5,11 +5,14 @@ import java.util.List; import sqlancer.Randomly; -import sqlancer.SQLConnection; -import sqlancer.StateToReproduce; +import sqlancer.common.gen.NoRECGenerator; +import sqlancer.common.schema.AbstractTables; import sqlancer.mariadb.MariaDBProvider; import sqlancer.mariadb.MariaDBSchema.MariaDBColumn; import sqlancer.mariadb.MariaDBSchema.MariaDBDataType; +import sqlancer.mariadb.MariaDBSchema.MariaDBTable; +import sqlancer.mariadb.ast.MariaDBAggregate; +import sqlancer.mariadb.ast.MariaDBAggregate.MariaDBAggregateFunction; import sqlancer.mariadb.ast.MariaDBBinaryOperator; import sqlancer.mariadb.ast.MariaDBBinaryOperator.MariaDBBinaryComparisonOperator; import sqlancer.mariadb.ast.MariaDBColumnName; @@ -18,14 +21,21 @@ import sqlancer.mariadb.ast.MariaDBFunction; import sqlancer.mariadb.ast.MariaDBFunctionName; import sqlancer.mariadb.ast.MariaDBInOperation; +import sqlancer.mariadb.ast.MariaDBJoin; import sqlancer.mariadb.ast.MariaDBPostfixUnaryOperation; import sqlancer.mariadb.ast.MariaDBPostfixUnaryOperation.MariaDBPostfixUnaryOperator; +import sqlancer.mariadb.ast.MariaDBSelectStatement; +import sqlancer.mariadb.ast.MariaDBSelectStatement.MariaDBSelectType; +import sqlancer.mariadb.ast.MariaDBTableReference; +import sqlancer.mariadb.ast.MariaDBText; import sqlancer.mariadb.ast.MariaDBUnaryPrefixOperation; import sqlancer.mariadb.ast.MariaDBUnaryPrefixOperation.MariaDBUnaryPrefixOperator; -public class MariaDBExpressionGenerator { +public class MariaDBExpressionGenerator + implements NoRECGenerator { private final Randomly r; + private List targetTables = new ArrayList<>(); private List columns = new ArrayList<>(); public MariaDBExpressionGenerator(Randomly r) { @@ -66,14 +76,6 @@ public MariaDBExpressionGenerator setColumns(List columns) { return this; } - public MariaDBExpressionGenerator setCon(SQLConnection con) { - return this; - } - - public MariaDBExpressionGenerator setState(StateToReproduce state) { - return this; - } - private enum ExpressionType { LITERAL, COLUMN, BINARY_COMPARISON, UNARY_POSTFIX_OPERATOR, UNARY_PREFIX_OPERATOR, FUNCTION, IN } @@ -146,4 +148,64 @@ public MariaDBExpression getRandomExpression() { return getRandomExpression(0); } + @Override + public MariaDBExpressionGenerator setTablesAndColumns(AbstractTables targetTables) { + this.targetTables = targetTables.getTables(); + this.columns = targetTables.getColumns(); + return this; + } + + @Override + public List getTableRefs() { + List tableRefs = new ArrayList<>(); + for (MariaDBTable t : targetTables) { + MariaDBTableReference tableRef = new MariaDBTableReference(t); + tableRefs.add(tableRef); + } + return tableRefs; + } + + @Override + public MariaDBExpression generateBooleanExpression() { + return getRandomExpression(); + } + + @Override + public MariaDBSelectStatement generateSelect() { + return new MariaDBSelectStatement(); + } + + @Override + public List getRandomJoinClauses() { + return MariaDBJoin.getRandomJoinClauses(targetTables, r); + } + + @Override + public String generateOptimizedQueryString(MariaDBSelectStatement select, MariaDBExpression whereCondition, + boolean shouldUseAggregate) { + if (shouldUseAggregate) { + MariaDBAggregate aggr = new MariaDBAggregate( + new MariaDBColumnName(new MariaDBColumn("*", MariaDBDataType.INT, false, 0)), + MariaDBAggregateFunction.COUNT); + select.setFetchColumns(Arrays.asList(aggr)); + } else { + MariaDBColumnName aggr = new MariaDBColumnName(MariaDBColumn.createDummy("*")); + select.setFetchColumns(Arrays.asList(aggr)); + } + + select.setWhereClause(whereCondition); + select.setSelectType(MariaDBSelectType.ALL); + return select.asString(); + } + + @Override + public String generateUnoptimizedQueryString(MariaDBSelectStatement select, MariaDBExpression whereCondition) { + MariaDBPostfixUnaryOperation isTrue = new MariaDBPostfixUnaryOperation(MariaDBPostfixUnaryOperator.IS_TRUE, + whereCondition); + MariaDBText asText = new MariaDBText(isTrue, " as count", false); + select.setFetchColumns(Arrays.asList(asText)); + select.setSelectType(MariaDBSelectType.ALL); + + return "SELECT SUM(count) FROM (" + select.asString() + ") as asdf"; + } } diff --git a/src/sqlancer/mariadb/gen/MariaDBIndexGenerator.java b/src/sqlancer/mariadb/gen/MariaDBIndexGenerator.java index 19519ca2e..1ba3fbd4d 100644 --- a/src/sqlancer/mariadb/gen/MariaDBIndexGenerator.java +++ b/src/sqlancer/mariadb/gen/MariaDBIndexGenerator.java @@ -19,6 +19,7 @@ public static SQLQueryAdapter generate(MariaDBSchema s) { ExpectedErrors errors = new ExpectedErrors(); StringBuilder sb = new StringBuilder("CREATE "); errors.add("Key/Index cannot be defined on a virtual generated column"); + errors.add("Specified key was too long"); if (Randomly.getBoolean()) { errors.add("Duplicate entry"); errors.add("Key/Index cannot be defined on a virtual generated column"); diff --git a/src/sqlancer/mariadb/gen/MariaDBSetGenerator.java b/src/sqlancer/mariadb/gen/MariaDBSetGenerator.java index d437d4537..860f60c4f 100644 --- a/src/sqlancer/mariadb/gen/MariaDBSetGenerator.java +++ b/src/sqlancer/mariadb/gen/MariaDBSetGenerator.java @@ -1,5 +1,6 @@ package sqlancer.mariadb.gen; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.function.Function; @@ -9,6 +10,8 @@ import sqlancer.Randomly; import sqlancer.common.query.ExpectedErrors; import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.mariadb.MariaDBBugs; +import sqlancer.mariadb.MariaDBProvider.MariaDBGlobalState; public class MariaDBSetGenerator { @@ -35,13 +38,12 @@ private enum Action { AUTOCOMMIT("autocommit", (r) -> 1, Scope.GLOBAL, Scope.SESSION), // BIG_TABLES("big_tables", (r) -> Randomly.fromOptions("OFF", "ON"), Scope.GLOBAL, Scope.SESSION), // - COMPLETION_TYPE("completion_type", (r) -> Randomly.fromOptions("'NO_CHAIN'", "'CHAIN'", "'RELEASE'", 0, 1, 2), - Scope.GLOBAL), // + COMPLETION_TYPE("completion_type", + (r) -> Randomly.fromOptions("'NO_CHAIN'", "'CHAIN'", "'RELEASE'", "0", "1", "2"), Scope.GLOBAL), // // BULK_INSERT_CACHE_SIZE("bulk_insert_buffer_size", (r) -> r.getLong(0, Long.MAX_VALUE), Scope.GLOBAL, // Scope.SESSION), - CONCURRENT_INSERT("concurrent_insert", (r) -> Randomly.fromOptions("NEVER", "AUTO", "ALWAYS", 0, 1, 2), + CONCURRENT_INSERT("concurrent_insert", (r) -> Randomly.fromOptions("NEVER", "AUTO", "ALWAYS", "0", "1", "2"), Scope.GLOBAL), - CTE_MAX_RECURSION_DEPTH("cte_max_recursion_depth", (r) -> r.getLong(0, 4294967295L), Scope.GLOBAL), DELAY_KEY_WRITE("delay_key_write", (r) -> Randomly.fromOptions("ON", "OFF", "ALL"), Scope.GLOBAL), EQ_RANGE_INDEX_DIVE_LIMIT("eq_range_index_dive_limit", (r) -> r.getLong(0, 4294967295L), Scope.GLOBAL), FLUSH("flush", (r) -> Randomly.fromOptions("OFF", "ON"), Scope.GLOBAL), @@ -86,7 +88,6 @@ private enum Action { */ // READ_BUFFER_SIZE("read_buffer_size", (r) -> r.getLong(8200, 2147479552), Scope.GLOBAL, Scope.SESSION), // READ_RND_BUFFER_SIZE("read_rnd_buffer_size", (r) -> r.getLong(1, 2147483647), Scope.GLOBAL, Scope.SESSION), - SCHEMA_DEFINITION_CACHE("schema_definition_cache", (r) -> r.getLong(256, 524288), Scope.GLOBAL), /* * sort_buffer_size is commented out as a workaround for https://bugs.mysql.com/bug.php?id=95969 */ @@ -120,27 +121,15 @@ private enum Action { private static String getOptimizerSwitchConfiguration(Randomly r) { StringBuilder sb = new StringBuilder(); sb.append("'"); - String[] options = { /* - * ("batched_key_access", /*"block_nested_loop", "condition_fanout_filter", - */ - "condition_pushdown_for_derived", // MariaDB - "derived_merge", // - "derived_with_keys", // MariaDB - "engine_condition_pushdown", // - "exists_to_in", // MariaDB - "extended_keys", // MariaDB - "firstmatch", // MariaDB - "index_condition_pushdown", // - /* "use_index_extensions", */ - "index_merge", // - "index_merge_intersection", // - "index_merge_sort_intersection", // - "index_merge_sort_union", // - "index_merge_union", "in_to_exists", // MariaDB - /* "use_invisible_indexes", */ "mrr", "mrr_cost_based", /* "skip_scan", */ "semijoin", /* - * "duplicateweedout", - */ - "firstmatch", "loosescan", "materialization", /* "subquery_materialization_cost_based" */ }; + String[] options = { "condition_pushdown_for_derived", "condition_pushdown_for_subquery", + "condition_pushdown_from_having", "derived_merge", "derived_with_keys", "exists_to_in", + "extended_keys", "firstmatch", "index_condition_pushdown", "hash_join_cardinality", "index_merge", + "index_merge_intersection", "index_merge_sort_intersection", "index_merge_sort_union", + "index_merge_union", "in_to_exists", "join_cache_bka", "join_cache_hashed", + "join_cache_incremental", "loosescan", "materialization", "mrr", "mrr_cost_based", "mrr_sort_keys", + "not_null_range_scan", "optimize_join_buffer_size", "orderby_uses_equalities", + "outer_join_with_cache", "partial_match_rowid_merge", "partial_match_table_scan", "rowid_filter", + "semijoin", "semijoin_with_cache", "split_materialized", "subquery_cache", "table_elimination" }; List optionSubset = Arrays.asList(Randomly.fromOptions(options)); sb.append(optionSubset.stream().map(s -> s + "=" + Randomly.fromOptions("on", "off")) .collect(Collectors.joining(","))); @@ -194,4 +183,60 @@ private SQLQueryAdapter get() { .from("At least one of the 'in_to_exists' or 'materialization' optimizer_switch flags must be 'on'")); } + public static SQLQueryAdapter resetOptimizer() { + return new SQLQueryAdapter("SET optimizer_switch='default'"); + } + + public static List getAllOptimizer(MariaDBGlobalState globalState) { + List result = new ArrayList<>(); + String[] options = { "condition_pushdown_for_derived", "condition_pushdown_for_subquery", + "condition_pushdown_from_having", "derived_merge", "derived_with_keys", "exists_to_in", "extended_keys", + "firstmatch", "index_condition_pushdown", "hash_join_cardinality", "index_merge", + "index_merge_intersection", "index_merge_sort_intersection", "index_merge_sort_union", + "index_merge_union", "in_to_exists", "join_cache_bka", "join_cache_hashed", "join_cache_incremental", + "loosescan", "materialization", "mrr", "mrr_cost_based", "mrr_sort_keys", "not_null_range_scan", + "optimize_join_buffer_size", "orderby_uses_equalities", "outer_join_with_cache", + "partial_match_rowid_merge", "partial_match_table_scan", "rowid_filter", "semijoin", + "semijoin_with_cache", "split_materialized", "subquery_cache", "table_elimination" }; + List availableOptions = new ArrayList<>(Arrays.asList(options)); + if (MariaDBBugs.bug21058) { + availableOptions.remove("in_to_exists"); // https://jira.mariadb.org/browse/MDEV-21058 + } + if (MariaDBBugs.bug32076) { + availableOptions.remove("not_null_range_scan"); // https://jira.mariadb.org/browse/MDEV-32076 + } + if (MariaDBBugs.bug32099) { + availableOptions.remove("optimize_join_buffer_size"); // https://jira.mariadb.org/browse/MDEV-32099 + } + if (MariaDBBugs.bug32105) { + availableOptions.remove("join_cache_hashed"); // https://jira.mariadb.org/browse/MDEV-32105 + } + if (MariaDBBugs.bug32106) { + availableOptions.remove("outer_join_with_cache"); // https://jira.mariadb.org/browse/MDEV-32106 + } + if (MariaDBBugs.bug32107) { + availableOptions.remove("table_elimination"); // https://jira.mariadb.org/browse/MDEV-32107 + } + if (MariaDBBugs.bug32108) { + availableOptions.remove("join_cache_incremental"); // https://jira.mariadb.org/browse/MDEV-32108 + } + if (MariaDBBugs.bug32143) { + availableOptions.remove("mrr"); // https://jira.mariadb.org/browse/MDEV-32143 + } + + StringBuilder sb = new StringBuilder(); + sb.append("SET SESSION optimizer_switch = '%s'"); + + for (String option : availableOptions) { + result.add(new SQLQueryAdapter(String.format(sb.toString(), option + "=on"), ExpectedErrors.from( + "At least one of the 'in_to_exists' or 'materialization' optimizer_switch flags must be 'on'"))); + result.add(new SQLQueryAdapter(String.format(sb.toString(), option + "=off"), ExpectedErrors.from( + "At least one of the 'in_to_exists' or 'materialization' optimizer_switch flags must be 'on'"))); + result.add(new SQLQueryAdapter(String.format(sb.toString(), option + "=default"), ExpectedErrors.from( + "At least one of the 'in_to_exists' or 'materialization' optimizer_switch flags must be 'on'"))); + } + + return result; + } + } diff --git a/src/sqlancer/mariadb/gen/MariaDBTableAdminCommandGenerator.java b/src/sqlancer/mariadb/gen/MariaDBTableAdminCommandGenerator.java index bff83e1e0..dd40739f4 100644 --- a/src/sqlancer/mariadb/gen/MariaDBTableAdminCommandGenerator.java +++ b/src/sqlancer/mariadb/gen/MariaDBTableAdminCommandGenerator.java @@ -8,6 +8,7 @@ import sqlancer.Randomly; import sqlancer.common.query.SQLQueryAdapter; import sqlancer.common.query.SQLQueryResultCheckAdapter; +import sqlancer.mariadb.MariaDBBugs; import sqlancer.mariadb.MariaDBSchema; import sqlancer.mariadb.MariaDBSchema.MariaDBTable; @@ -54,10 +55,13 @@ public static SQLQueryAdapter checkTable(MariaDBSchema newSchema) { public static SQLQueryAdapter optimizeTable(MariaDBSchema newSchema) { StringBuilder sb = addCommandAndTables(newSchema, "OPTIMIZE TABLE"); - MariaDBCommon.addWaitClause(sb); + if (!MariaDBBugs.bug33893) { + MariaDBCommon.addWaitClause(sb); + } return checkForMsgText(sb, s -> s.equals("OK") || s.equals("Table does not support optimize, doing recreate + analyze instead") - || s.contentEquals("Table is already up to date")); + || s.contentEquals("Table is already up to date") || s.contains("Lock wait timeout") + || s.contains("Operation failed")); } private static SQLQueryAdapter checkForMsgText(StringBuilder sb, Function checker) { diff --git a/src/sqlancer/mariadb/gen/MariaDBTableGenerator.java b/src/sqlancer/mariadb/gen/MariaDBTableGenerator.java index 4376b62db..23f09a420 100644 --- a/src/sqlancer/mariadb/gen/MariaDBTableGenerator.java +++ b/src/sqlancer/mariadb/gen/MariaDBTableGenerator.java @@ -142,6 +142,7 @@ private void createOrReplaceTable() { sb.append("IF NOT EXISTS "); } sb.append(tableName); + errors.add("Specified key was too long; max key length is"); } } diff --git a/src/sqlancer/mariadb/gen/MariaDBTruncateGenerator.java b/src/sqlancer/mariadb/gen/MariaDBTruncateGenerator.java index 04e567444..ecc240bb3 100644 --- a/src/sqlancer/mariadb/gen/MariaDBTruncateGenerator.java +++ b/src/sqlancer/mariadb/gen/MariaDBTruncateGenerator.java @@ -1,6 +1,8 @@ package sqlancer.mariadb.gen; +import sqlancer.common.query.ExpectedErrors; import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.mariadb.MariaDBErrors; import sqlancer.mariadb.MariaDBSchema; public final class MariaDBTruncateGenerator { @@ -13,7 +15,9 @@ public static SQLQueryAdapter truncate(MariaDBSchema s) { sb.append(s.getRandomTable().getName()); sb.append(" "); MariaDBCommon.addWaitClause(sb); - return new SQLQueryAdapter(sb.toString()); + ExpectedErrors errors = new ExpectedErrors(); + MariaDBErrors.addCommonErrors(errors); + return new SQLQueryAdapter(sb.toString(), errors); } } diff --git a/src/sqlancer/mariadb/oracle/MariaDBDQPOracle.java b/src/sqlancer/mariadb/oracle/MariaDBDQPOracle.java new file mode 100644 index 000000000..dcd458193 --- /dev/null +++ b/src/sqlancer/mariadb/oracle/MariaDBDQPOracle.java @@ -0,0 +1,92 @@ +package sqlancer.mariadb.oracle; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.ComparatorHelper; +import sqlancer.Randomly; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.mariadb.MariaDBErrors; +import sqlancer.mariadb.MariaDBProvider.MariaDBGlobalState; +import sqlancer.mariadb.MariaDBSchema; +import sqlancer.mariadb.MariaDBSchema.MariaDBTables; +import sqlancer.mariadb.ast.MariaDBColumnName; +import sqlancer.mariadb.ast.MariaDBExpression; +import sqlancer.mariadb.ast.MariaDBJoin; +import sqlancer.mariadb.ast.MariaDBSelectStatement; +import sqlancer.mariadb.ast.MariaDBTableReference; +import sqlancer.mariadb.ast.MariaDBVisitor; +import sqlancer.mariadb.gen.MariaDBExpressionGenerator; +import sqlancer.mariadb.gen.MariaDBSetGenerator; + +public class MariaDBDQPOracle implements TestOracle { + private final MariaDBGlobalState state; + private final MariaDBSchema s; + private MariaDBExpressionGenerator gen; + private MariaDBSelectStatement select; + private final ExpectedErrors errors = new ExpectedErrors(); + + public MariaDBDQPOracle(MariaDBGlobalState globalState) { + state = globalState; + s = globalState.getSchema(); + MariaDBErrors.addCommonErrors(errors); + } + + @Override + public void check() throws Exception { + MariaDBTables tables = s.getRandomTableNonEmptyTables(); + gen = new MariaDBExpressionGenerator(state.getRandomly()).setColumns(tables.getColumns()); + + List fetchColumns = new ArrayList<>(); + fetchColumns.addAll(Randomly.nonEmptySubset(tables.getColumns()).stream().map(c -> new MariaDBColumnName(c)) + .collect(Collectors.toList())); + + select = new MariaDBSelectStatement(); + select.setFetchColumns(fetchColumns); + + select.setSelectType(Randomly.fromOptions(MariaDBSelectStatement.MariaDBSelectType.values())); + if (Randomly.getBoolean()) { + select.setWhereClause(gen.getRandomExpression()); + } + if (Randomly.getBoolean()) { + select.setGroupByClause(fetchColumns); + } + + // Set the join. + List joinExpressions = MariaDBJoin.getRandomJoinClauses(tables.getTables(), state.getRandomly()); + select.setJoinClauses(joinExpressions); + + // Set the from clause from the tables that are not used in the join. + select.setFromList( + tables.getTables().stream().map(t -> new MariaDBTableReference(t)).collect(Collectors.toList())); + + // Get the result of the first query + String originalQueryString = MariaDBVisitor.asString(select); + List originalResult = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, + state); + + List optimizationList = MariaDBSetGenerator.getAllOptimizer(state); + for (SQLQueryAdapter optimization : optimizationList) { + optimization.execute(state); + List result = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); + try { + ComparatorHelper.assumeResultSetsAreEqual(originalResult, result, originalQueryString, + List.of(originalQueryString), state); + } catch (AssertionError e) { + String assertionMessage = String.format( + "The size of the result sets mismatch (%d and %d)!" + System.lineSeparator() + + "First query: \"%s\", whose cardinality is: %d" + System.lineSeparator() + + "Second query:\"%s\", whose cardinality is: %d", + originalResult.size(), result.size(), originalQueryString, originalResult.size(), + String.join(";", originalQueryString), result.size()); + assertionMessage += System.lineSeparator() + "The setting: " + optimization.getQueryString(); + throw new AssertionError(assertionMessage); + } + } + + } + +} diff --git a/src/sqlancer/mariadb/oracle/MariaDBNoRECOracle.java b/src/sqlancer/mariadb/oracle/MariaDBNoRECOracle.java deleted file mode 100644 index 882b49c08..000000000 --- a/src/sqlancer/mariadb/oracle/MariaDBNoRECOracle.java +++ /dev/null @@ -1,125 +0,0 @@ -package sqlancer.mariadb.oracle; - -import java.sql.SQLException; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; - -import sqlancer.IgnoreMeException; -import sqlancer.common.oracle.NoRECBase; -import sqlancer.common.oracle.TestOracle; -import sqlancer.common.query.SQLQueryAdapter; -import sqlancer.common.query.SQLancerResultSet; -import sqlancer.mariadb.MariaDBProvider.MariaDBGlobalState; -import sqlancer.mariadb.MariaDBSchema; -import sqlancer.mariadb.MariaDBSchema.MariaDBColumn; -import sqlancer.mariadb.MariaDBSchema.MariaDBDataType; -import sqlancer.mariadb.MariaDBSchema.MariaDBTable; -import sqlancer.mariadb.ast.MariaDBAggregate; -import sqlancer.mariadb.ast.MariaDBAggregate.MariaDBAggregateFunction; -import sqlancer.mariadb.ast.MariaDBColumnName; -import sqlancer.mariadb.ast.MariaDBExpression; -import sqlancer.mariadb.ast.MariaDBPostfixUnaryOperation; -import sqlancer.mariadb.ast.MariaDBPostfixUnaryOperation.MariaDBPostfixUnaryOperator; -import sqlancer.mariadb.ast.MariaDBSelectStatement; -import sqlancer.mariadb.ast.MariaDBSelectStatement.MariaDBSelectType; -import sqlancer.mariadb.ast.MariaDBText; -import sqlancer.mariadb.ast.MariaDBVisitor; -import sqlancer.mariadb.gen.MariaDBExpressionGenerator; - -public class MariaDBNoRECOracle extends NoRECBase implements TestOracle { - - private final MariaDBSchema s; - private static final int NOT_FOUND = -1; - - public MariaDBNoRECOracle(MariaDBGlobalState globalState) { - super(globalState); - this.s = globalState.getSchema(); - errors.add("is out of range"); - // regex - errors.add("unmatched parentheses"); - errors.add("nothing to repeat at offset"); - errors.add("missing )"); - errors.add("missing terminating ]"); - errors.add("range out of order in character class"); - errors.add("unrecognized character after "); - errors.add("Got error '(*VERB) not recognized or malformed"); - errors.add("must be followed by"); - errors.add("malformed number or name after"); - errors.add("digit expected after"); - } - - @Override - public void check() throws SQLException { - MariaDBTable randomTable = s.getRandomTable(); - List columns = randomTable.getColumns(); - MariaDBExpressionGenerator gen = new MariaDBExpressionGenerator(state.getRandomly()).setColumns(columns) - .setCon(con).setState(state.getState()); - MariaDBExpression randomWhereCondition = gen.getRandomExpression(); - List groupBys = Collections.emptyList(); // getRandomExpressions(columns); - int optimizedCount = getOptimizedQuery(randomTable, randomWhereCondition, groupBys); - int unoptimizedCount = getUnoptimizedQuery(randomTable, randomWhereCondition, groupBys); - if (optimizedCount == NOT_FOUND || unoptimizedCount == NOT_FOUND) { - throw new IgnoreMeException(); - } - if (optimizedCount != unoptimizedCount) { - state.getState().getLocalState().log(optimizedQueryString + ";\n" + unoptimizedQueryString + ";"); - throw new AssertionError(optimizedCount + " " + unoptimizedCount); - } - } - - private int getUnoptimizedQuery(MariaDBTable randomTable, MariaDBExpression randomWhereCondition, - List groupBys) throws SQLException { - MariaDBSelectStatement select = new MariaDBSelectStatement(); - select.setGroupByClause(groupBys); - MariaDBPostfixUnaryOperation isTrue = new MariaDBPostfixUnaryOperation(MariaDBPostfixUnaryOperator.IS_TRUE, - randomWhereCondition); - MariaDBText asText = new MariaDBText(isTrue, " as count", false); - select.setFetchColumns(Arrays.asList(asText)); - select.setFromTables(Arrays.asList(randomTable)); - select.setSelectType(MariaDBSelectType.ALL); - int secondCount = 0; - - unoptimizedQueryString = "SELECT SUM(count) FROM (" + MariaDBVisitor.asString(select) + ") as asdf"; - SQLQueryAdapter q = new SQLQueryAdapter(unoptimizedQueryString, errors); - try (SQLancerResultSet rs = q.executeAndGet(state)) { - if (rs == null) { - return NOT_FOUND; - } else { - while (rs.next()) { - secondCount = rs.getInt(1); - } - } - } - - return secondCount; - } - - private int getOptimizedQuery(MariaDBTable randomTable, MariaDBExpression randomWhereCondition, - List groupBys) throws SQLException { - MariaDBSelectStatement select = new MariaDBSelectStatement(); - select.setGroupByClause(groupBys); - MariaDBAggregate aggr = new MariaDBAggregate( - new MariaDBColumnName(new MariaDBColumn("*", MariaDBDataType.INT, false, 0)), - MariaDBAggregateFunction.COUNT); - select.setFetchColumns(Arrays.asList(aggr)); - select.setFromTables(Arrays.asList(randomTable)); - select.setWhereClause(randomWhereCondition); - select.setSelectType(MariaDBSelectType.ALL); - int firstCount; - optimizedQueryString = MariaDBVisitor.asString(select); - SQLQueryAdapter q = new SQLQueryAdapter(optimizedQueryString, errors); - try (SQLancerResultSet rs = q.executeAndGet(state)) { - if (rs == null) { - firstCount = NOT_FOUND; - } else { - rs.next(); - firstCount = rs.getInt(1); - } - } catch (Exception e) { - throw new AssertionError(optimizedQueryString, e); - } - return firstCount; - } - -} diff --git a/src/sqlancer/materialize/MaterializeCompoundDataType.java b/src/sqlancer/materialize/MaterializeCompoundDataType.java new file mode 100644 index 000000000..a4bddf7ca --- /dev/null +++ b/src/sqlancer/materialize/MaterializeCompoundDataType.java @@ -0,0 +1,46 @@ +package sqlancer.materialize; + +import java.util.Optional; + +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; + +public final class MaterializeCompoundDataType { + + private final MaterializeDataType dataType; + private final MaterializeCompoundDataType elemType; + private final Integer size; + + private MaterializeCompoundDataType(MaterializeDataType dataType, MaterializeCompoundDataType elemType, + Integer size) { + this.dataType = dataType; + this.elemType = elemType; + this.size = size; + } + + public MaterializeDataType getDataType() { + return dataType; + } + + public MaterializeCompoundDataType getElemType() { + if (elemType == null) { + throw new AssertionError(); + } + return elemType; + } + + public Optional getSize() { + if (size == null) { + return Optional.empty(); + } else { + return Optional.of(size); + } + } + + public static MaterializeCompoundDataType create(MaterializeDataType type, int size) { + return new MaterializeCompoundDataType(type, null, size); + } + + public static MaterializeCompoundDataType create(MaterializeDataType type) { + return new MaterializeCompoundDataType(type, null, null); + } +} diff --git a/src/sqlancer/materialize/MaterializeExpectedValueVisitor.java b/src/sqlancer/materialize/MaterializeExpectedValueVisitor.java new file mode 100644 index 000000000..caa801268 --- /dev/null +++ b/src/sqlancer/materialize/MaterializeExpectedValueVisitor.java @@ -0,0 +1,162 @@ +package sqlancer.materialize; + +import sqlancer.materialize.ast.MaterializeAggregate; +import sqlancer.materialize.ast.MaterializeBetweenOperation; +import sqlancer.materialize.ast.MaterializeBinaryLogicalOperation; +import sqlancer.materialize.ast.MaterializeCastOperation; +import sqlancer.materialize.ast.MaterializeColumnValue; +import sqlancer.materialize.ast.MaterializeConstant; +import sqlancer.materialize.ast.MaterializeExpression; +import sqlancer.materialize.ast.MaterializeFunction; +import sqlancer.materialize.ast.MaterializeInOperation; +import sqlancer.materialize.ast.MaterializeLikeOperation; +import sqlancer.materialize.ast.MaterializeOrderByTerm; +import sqlancer.materialize.ast.MaterializePOSIXRegularExpression; +import sqlancer.materialize.ast.MaterializePostfixOperation; +import sqlancer.materialize.ast.MaterializePostfixText; +import sqlancer.materialize.ast.MaterializePrefixOperation; +import sqlancer.materialize.ast.MaterializeSelect; +import sqlancer.materialize.ast.MaterializeSelect.MaterializeFromTable; +import sqlancer.materialize.ast.MaterializeSelect.MaterializeSubquery; +import sqlancer.materialize.ast.MaterializeSimilarTo; + +public final class MaterializeExpectedValueVisitor implements MaterializeVisitor { + + private final StringBuilder sb = new StringBuilder(); + private static final int NR_TABS = 0; + + private void print(MaterializeExpression expr) { + MaterializeToStringVisitor v = new MaterializeToStringVisitor(); + v.visit(expr); + for (int i = 0; i < NR_TABS; i++) { + sb.append("\t"); + } + sb.append(v.get()); + sb.append(" -- "); + sb.append(expr.getExpectedValue()); + sb.append("\n"); + } + + @Override + public void visit(MaterializeConstant constant) { + print(constant); + } + + @Override + public void visit(MaterializePostfixOperation op) { + print(op); + visit(op.getExpression()); + } + + public String get() { + return sb.toString(); + } + + @Override + public void visit(MaterializeColumnValue c) { + print(c); + } + + @Override + public void visit(MaterializePrefixOperation op) { + print(op); + visit(op.getExpression()); + } + + @Override + public void visit(MaterializeSelect op) { + visit(op.getWhereClause()); + } + + @Override + public void visit(MaterializeOrderByTerm op) { + + } + + @Override + public void visit(MaterializeFunction f) { + print(f); + for (int i = 0; i < f.getArguments().length; i++) { + visit(f.getArguments()[i]); + } + } + + @Override + public void visit(MaterializeCastOperation cast) { + print(cast); + visit(cast.getExpression()); + } + + @Override + public void visit(MaterializeBetweenOperation op) { + print(op); + visit(op.getExpr()); + visit(op.getLeft()); + visit(op.getRight()); + } + + @Override + public void visit(MaterializeInOperation op) { + print(op); + visit(op.getExpr()); + for (MaterializeExpression right : op.getListElements()) { + visit(right); + } + } + + @Override + public void visit(MaterializePostfixText op) { + print(op); + visit(op.getExpr()); + } + + @Override + public void visit(MaterializeAggregate op) { + print(op); + for (MaterializeExpression expr : op.getArgs()) { + visit(expr); + } + } + + @Override + public void visit(MaterializeSimilarTo op) { + print(op); + visit(op.getString()); + visit(op.getSimilarTo()); + if (op.getEscapeCharacter() != null) { + visit(op.getEscapeCharacter()); + } + } + + @Override + public void visit(MaterializePOSIXRegularExpression op) { + print(op); + visit(op.getString()); + visit(op.getRegex()); + } + + @Override + public void visit(MaterializeFromTable from) { + print(from); + } + + @Override + public void visit(MaterializeSubquery subquery) { + print(subquery); + } + + @Override + public void visit(MaterializeBinaryLogicalOperation op) { + print(op); + visit(op.getLeft()); + visit(op.getRight()); + } + + @Override + public void visit(MaterializeLikeOperation op) { + print(op); + visit(op.getLeft()); + visit(op.getRight()); + } + +} diff --git a/src/sqlancer/materialize/MaterializeGlobalState.java b/src/sqlancer/materialize/MaterializeGlobalState.java new file mode 100644 index 000000000..77cbe5d14 --- /dev/null +++ b/src/sqlancer/materialize/MaterializeGlobalState.java @@ -0,0 +1,292 @@ +package sqlancer.materialize; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.SQLGlobalState; + +public class MaterializeGlobalState extends SQLGlobalState { + + public static final char IMMUTABLE = 'i'; + public static final char STABLE = 's'; + public static final char VOLATILE = 'v'; + + private List operators = Collections.emptyList(); + private List collates = Collections.emptyList(); + private List opClasses = Collections.emptyList(); + private List tableAccessMethods = Collections.emptyList(); + // store and allow filtering by function volatility classifications + private final Map functionsAndTypes = new HashMap<>(); + private List allowedFunctionTypes = Arrays.asList(IMMUTABLE, STABLE, VOLATILE); + + @Override + public void setConnection(SQLConnection con) { + super.setConnection(con); + try { + this.opClasses = getOpclasses(); + this.operators = getOperators(getConnection()); + this.collates = getCollnames(getConnection()); + this.tableAccessMethods = getTableAccessMethods(getConnection()); + } catch (SQLException e) { + throw new AssertionError(e); + } + } + + private List getCollnames(SQLConnection con) throws SQLException { + List collNames = new ArrayList<>(); + try (Statement s = con.createStatement()) { + try (ResultSet rs = s + .executeQuery("SELECT collname FROM pg_collation WHERE collname LIKE '%utf8' or collname = 'C';")) { + while (rs.next()) { + collNames.add(rs.getString(1)); + } + } + } + return collNames; + } + + private List getOpclasses() throws SQLException { + List opClasses = new ArrayList<>(); + // select opcname FROM pg_opclass; + // ERROR: unknown catalog item 'pg_opclass' + opClasses.add("array_ops"); + opClasses.add("array_ops"); + opClasses.add("bit_ops"); + opClasses.add("bool_ops"); + opClasses.add("bpchar_ops"); + opClasses.add("bpchar_ops"); + opClasses.add("bytea_ops"); + opClasses.add("char_ops"); + opClasses.add("char_ops"); + opClasses.add("cidr_ops"); + opClasses.add("cidr_ops"); + opClasses.add("date_ops"); + opClasses.add("date_ops"); + opClasses.add("float4_ops"); + opClasses.add("float4_ops"); + opClasses.add("float8_ops"); + opClasses.add("float8_ops"); + opClasses.add("inet_ops"); + opClasses.add("inet_ops"); + opClasses.add("inet_ops"); + opClasses.add("inet_ops"); + opClasses.add("int2_ops"); + opClasses.add("int2_ops"); + opClasses.add("int4_ops"); + opClasses.add("int4_ops"); + opClasses.add("int8_ops"); + opClasses.add("int8_ops"); + opClasses.add("interval_ops"); + opClasses.add("interval_ops"); + opClasses.add("macaddr_ops"); + opClasses.add("macaddr_ops"); + opClasses.add("macaddr8_ops"); + opClasses.add("macaddr8_ops"); + opClasses.add("name_ops"); + opClasses.add("name_ops"); + opClasses.add("numeric_ops"); + opClasses.add("numeric_ops"); + opClasses.add("oid_ops"); + opClasses.add("oid_ops"); + opClasses.add("oidvector_ops"); + opClasses.add("oidvector_ops"); + opClasses.add("record_ops"); + opClasses.add("record_image_ops"); + opClasses.add("text_ops"); + opClasses.add("text_ops"); + opClasses.add("time_ops"); + opClasses.add("time_ops"); + opClasses.add("timestamptz_ops"); + opClasses.add("timestamptz_ops"); + opClasses.add("timetz_ops"); + opClasses.add("timetz_ops"); + opClasses.add("varbit_ops"); + opClasses.add("varchar_ops"); + opClasses.add("varchar_ops"); + opClasses.add("timestamp_ops"); + opClasses.add("timestamp_ops"); + opClasses.add("text_pattern_ops"); + opClasses.add("varchar_pattern_ops"); + opClasses.add("bpchar_pattern_ops"); + opClasses.add("money_ops"); + opClasses.add("bool_ops"); + opClasses.add("bytea_ops"); + opClasses.add("tid_ops"); + opClasses.add("xid_ops"); + opClasses.add("cid_ops"); + opClasses.add("tid_ops"); + opClasses.add("text_pattern_ops"); + opClasses.add("varchar_pattern_ops"); + opClasses.add("bpchar_pattern_ops"); + opClasses.add("aclitem_ops"); + opClasses.add("box_ops"); + opClasses.add("point_ops"); + opClasses.add("text_pattern_ops"); + opClasses.add("varchar_pattern_ops"); + opClasses.add("bpchar_pattern_ops"); + opClasses.add("money_ops"); + opClasses.add("bool_ops"); + opClasses.add("bytea_ops"); + opClasses.add("tid_ops"); + opClasses.add("xid_ops"); + opClasses.add("cid_ops"); + opClasses.add("tid_ops"); + opClasses.add("text_pattern_ops"); + opClasses.add("varchar_pattern_ops"); + opClasses.add("bpchar_pattern_ops"); + opClasses.add("aclitem_ops"); + opClasses.add("box_ops"); + opClasses.add("point_ops"); + opClasses.add("poly_ops"); + opClasses.add("circle_ops"); + opClasses.add("array_ops"); + opClasses.add("uuid_ops"); + opClasses.add("uuid_ops"); + opClasses.add("pg_lsn_ops"); + opClasses.add("pg_lsn_ops"); + opClasses.add("enum_ops"); + opClasses.add("enum_ops"); + opClasses.add("tsvector_ops"); + opClasses.add("tsvector_ops"); + opClasses.add("tsvector_ops"); + opClasses.add("tsquery_ops"); + opClasses.add("tsquery_ops"); + opClasses.add("range_ops"); + opClasses.add("range_ops"); + opClasses.add("range_ops"); + opClasses.add("range_ops"); + opClasses.add("box_ops"); + opClasses.add("quad_point_ops"); + opClasses.add("kd_point_ops"); + opClasses.add("text_ops"); + opClasses.add("poly_ops"); + opClasses.add("jsonb_ops"); + opClasses.add("jsonb_ops"); + opClasses.add("jsonb_ops"); + opClasses.add("jsonb_path_ops"); + opClasses.add("bytea_minmax_ops"); + opClasses.add("char_minmax_ops"); + opClasses.add("name_minmax_ops"); + opClasses.add("int8_minmax_ops"); + opClasses.add("int2_minmax_ops"); + opClasses.add("int4_minmax_ops"); + opClasses.add("text_minmax_ops"); + opClasses.add("oid_minmax_ops"); + opClasses.add("tid_minmax_ops"); + opClasses.add("float4_minmax_ops"); + opClasses.add("float8_minmax_ops"); + opClasses.add("macaddr_minmax_ops"); + opClasses.add("macaddr8_minmax_ops"); + opClasses.add("inet_minmax_ops"); + opClasses.add("inet_inclusion_ops"); + opClasses.add("bpchar_minmax_ops"); + opClasses.add("time_minmax_ops"); + opClasses.add("date_minmax_ops"); + opClasses.add("timestamp_minmax_ops"); + opClasses.add("timestamptz_minmax_ops"); + opClasses.add("interval_minmax_ops"); + opClasses.add("timetz_minmax_ops"); + opClasses.add("bit_minmax_ops"); + opClasses.add("varbit_minmax_ops"); + opClasses.add("numeric_minmax_ops"); + opClasses.add("uuid_minmax_ops"); + opClasses.add("range_inclusion_ops"); + opClasses.add("pg_lsn_minmax_ops"); + opClasses.add("box_inclusion_ops"); + return opClasses; + } + + private List getOperators(SQLConnection con) throws SQLException { + List operators = new ArrayList<>(); + try (Statement s = con.createStatement()) { + try (ResultSet rs = s.executeQuery("SELECT oprname FROM pg_operator;")) { + while (rs.next()) { + operators.add(rs.getString(1)); + } + } + } + return operators; + } + + private List getTableAccessMethods(SQLConnection con) throws SQLException { + List tableAccessMethods = new ArrayList<>(); + try (Statement s = con.createStatement()) { + /* + * pg_am includes both index and table access methods so we need to filter with amtype = 't' + */ + try (ResultSet rs = s.executeQuery("SELECT amname FROM pg_am WHERE amtype = 't';")) { + while (rs.next()) { + tableAccessMethods.add(rs.getString(1)); + } + } + } + return tableAccessMethods; + } + + public List getOperators() { + return operators; + } + + public String getRandomOperator() { + return Randomly.fromList(operators); + } + + public List getCollates() { + return collates; + } + + public String getRandomCollate() { + return Randomly.fromList(collates); + } + + public List getOpClasses() { + return opClasses; + } + + public String getRandomOpclass() { + return Randomly.fromList(opClasses); + } + + public List getTableAccessMethods() { + return tableAccessMethods; + } + + public String getRandomTableAccessMethod() { + return Randomly.fromList(tableAccessMethods); + } + + @Override + public MaterializeSchema readSchema() throws SQLException { + return MaterializeSchema.fromConnection(getConnection(), getDatabaseName()); + } + + public void addFunctionAndType(String functionName, Character functionType) { + this.functionsAndTypes.put(functionName, functionType); + } + + public Map getFunctionsAndTypes() { + return this.functionsAndTypes; + } + + public void setAllowedFunctionTypes(List types) { + this.allowedFunctionTypes = types; + } + + public void setDefaultAllowedFunctionTypes() { + this.allowedFunctionTypes = Arrays.asList(IMMUTABLE, STABLE, VOLATILE); + } + + public List getAllowedFunctionTypes() { + return this.allowedFunctionTypes; + } + +} diff --git a/src/sqlancer/materialize/MaterializeOptions.java b/src/sqlancer/materialize/MaterializeOptions.java new file mode 100644 index 000000000..104dc9b51 --- /dev/null +++ b/src/sqlancer/materialize/MaterializeOptions.java @@ -0,0 +1,43 @@ +package sqlancer.materialize; + +import java.util.Arrays; +import java.util.List; + +import com.beust.jcommander.Parameter; +import com.beust.jcommander.Parameters; + +import sqlancer.DBMSSpecificOptions; + +@Parameters(separators = "=", commandDescription = "Materialize (default port: " + MaterializeOptions.DEFAULT_PORT + + ", default host: " + MaterializeOptions.DEFAULT_HOST + ", default user: " + MaterializeOptions.DEFAULT_USER + + ")") +public class MaterializeOptions implements DBMSSpecificOptions { + public static final String DEFAULT_HOST = "localhost"; + public static final String DEFAULT_USER = "materialize"; + public static final int DEFAULT_PORT = 6875; + + @Parameter(names = "--bulk-insert", description = "Specifies whether INSERT statements should be issued in bulk", arity = 1) + public boolean allowBulkInsert; + + @Parameter(names = "--oracle", description = "Specifies which test oracle should be used for Materialize") + public List oracle = Arrays.asList(MaterializeOracleFactory.QUERY_PARTITIONING); + + @Parameter(names = "--test-collations", description = "Specifies whether to test different collations", arity = 1) + public boolean testCollations = true; + + @Parameter(names = "--set-max-tables-mvs", description = "Specifies whether to set the maximum number of tables and materialized views intiially", arity = 1) + public boolean setMaxTablesMVs; + + @Parameter(names = "--connection-url", description = "Specifies the URL for connecting to the Materialize server", arity = 1) + public String connectionURL = String.format("postgresql://%s@%s:%d/test", MaterializeOptions.DEFAULT_USER, + MaterializeOptions.DEFAULT_HOST, MaterializeOptions.DEFAULT_PORT); + + @Parameter(names = "--extensions", description = "Specifies a comma-separated list of extension names to be created in each test database", arity = 1) + public String extensions = ""; + + @Override + public List getTestOracleFactory() { + return oracle; + } + +} diff --git a/src/sqlancer/materialize/MaterializeOracleFactory.java b/src/sqlancer/materialize/MaterializeOracleFactory.java new file mode 100644 index 000000000..c652fd56a --- /dev/null +++ b/src/sqlancer/materialize/MaterializeOracleFactory.java @@ -0,0 +1,70 @@ +package sqlancer.materialize; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import sqlancer.OracleFactory; +import sqlancer.common.oracle.CompositeTestOracle; +import sqlancer.common.oracle.NoRECOracle; +import sqlancer.common.oracle.TLPWhereOracle; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.materialize.gen.MaterializeCommon; +import sqlancer.materialize.gen.MaterializeExpressionGenerator; +import sqlancer.materialize.oracle.MaterializePivotedQuerySynthesisOracle; +import sqlancer.materialize.oracle.tlp.MaterializeTLPAggregateOracle; +import sqlancer.materialize.oracle.tlp.MaterializeTLPHavingOracle; + +public enum MaterializeOracleFactory implements OracleFactory { + NOREC { + @Override + public TestOracle create(MaterializeGlobalState globalState) throws SQLException { + MaterializeExpressionGenerator gen = new MaterializeExpressionGenerator(globalState); + ExpectedErrors errors = ExpectedErrors.newErrors().with(MaterializeCommon.getCommonExpressionErrors()) + .with(MaterializeCommon.getCommonFetchErrors()).with("canceling statement due to statement timeout") + .build(); + return new NoRECOracle<>(globalState, gen, errors); + } + }, + PQS { + @Override + public TestOracle create(MaterializeGlobalState globalState) throws SQLException { + return new MaterializePivotedQuerySynthesisOracle(globalState); + } + + @Override + public boolean requiresAllTablesToContainRows() { + return true; + } + }, + WHERE { + @Override + public TestOracle create(MaterializeGlobalState globalState) throws SQLException { + MaterializeExpressionGenerator gen = new MaterializeExpressionGenerator(globalState); + ExpectedErrors expectedErrors = ExpectedErrors.newErrors() + .with(MaterializeCommon.getCommonExpressionErrors()).with(MaterializeCommon.getCommonFetchErrors()) + .build(); + + return new TLPWhereOracle<>(globalState, gen, expectedErrors); + } + }, + HAVING { + @Override + public TestOracle create(MaterializeGlobalState globalState) throws SQLException { + return new MaterializeTLPHavingOracle(globalState); + } + + }, + QUERY_PARTITIONING { + @Override + public TestOracle create(MaterializeGlobalState globalState) throws Exception { + List> oracles = new ArrayList<>(); + oracles.add(WHERE.create(globalState)); + oracles.add(HAVING.create(globalState)); + oracles.add(new MaterializeTLPAggregateOracle(globalState)); + return new CompositeTestOracle(oracles, globalState); + } + }; + +} diff --git a/src/sqlancer/materialize/MaterializeProvider.java b/src/sqlancer/materialize/MaterializeProvider.java new file mode 100644 index 000000000..e7bdb4c4f --- /dev/null +++ b/src/sqlancer/materialize/MaterializeProvider.java @@ -0,0 +1,345 @@ +package sqlancer.materialize; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.StringReader; +import java.net.URI; +import java.net.URISyntaxException; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Arrays; + +import com.google.auto.service.AutoService; + +import sqlancer.AbstractAction; +import sqlancer.DatabaseProvider; +import sqlancer.IgnoreMeException; +import sqlancer.MainOptions; +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.SQLProviderAdapter; +import sqlancer.StatementExecutor; +import sqlancer.common.DBMSCommon; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.query.SQLQueryProvider; +import sqlancer.common.query.SQLancerResultSet; +import sqlancer.materialize.gen.MaterializeDeleteGenerator; +import sqlancer.materialize.gen.MaterializeDropIndexGenerator; +import sqlancer.materialize.gen.MaterializeIndexGenerator; +import sqlancer.materialize.gen.MaterializeInsertGenerator; +import sqlancer.materialize.gen.MaterializeTableGenerator; +import sqlancer.materialize.gen.MaterializeUpdateGenerator; +import sqlancer.materialize.gen.MaterializeViewGenerator; + +// EXISTS +// IN +@AutoService(DatabaseProvider.class) +public class MaterializeProvider extends SQLProviderAdapter { + + /** + * Generate only data types and expressions that are understood by PQS. + */ + public static boolean generateOnlyKnown; + + protected String entryURL; + protected String username; + protected String password; + protected String entryPath; + protected String host; + protected int port; + protected String testURL; + protected String databaseName; + protected String createDatabaseCommand; + protected String extensionsList; + + public MaterializeProvider() { + super(MaterializeGlobalState.class, MaterializeOptions.class); + } + + protected MaterializeProvider(Class globalClass, Class optionClass) { + super(globalClass, optionClass); + } + + public enum Action implements AbstractAction { + DELETE(MaterializeDeleteGenerator::create), // + DROP_INDEX(MaterializeDropIndexGenerator::create), // + INSERT(MaterializeInsertGenerator::insert), // + UPDATE(MaterializeUpdateGenerator::create), // + CREATE_INDEX(MaterializeIndexGenerator::generate), // + CREATE_VIEW(MaterializeViewGenerator::create); + + private final SQLQueryProvider sqlQueryProvider; + + Action(SQLQueryProvider sqlQueryProvider) { + this.sqlQueryProvider = sqlQueryProvider; + } + + @Override + public SQLQueryAdapter getQuery(MaterializeGlobalState state) throws Exception { + return sqlQueryProvider.getQuery(state); + } + } + + protected static int mapActions(MaterializeGlobalState globalState, Action a) { + Randomly r = globalState.getRandomly(); + int nrPerformed; + switch (a) { + case CREATE_INDEX: + nrPerformed = r.getInteger(0, 3); + break; + case DROP_INDEX: + nrPerformed = r.getInteger(0, 5); + break; + case DELETE: + nrPerformed = r.getInteger(0, 5); + break; + case CREATE_VIEW: + nrPerformed = r.getInteger(0, 2); + break; + case UPDATE: + nrPerformed = r.getInteger(0, 10); + break; + case INSERT: + nrPerformed = r.getInteger(0, globalState.getOptions().getMaxNumberInserts()); + break; + default: + throw new AssertionError(a); + } + return nrPerformed; + + } + + @Override + public void generateDatabase(MaterializeGlobalState globalState) throws Exception { + readFunctions(globalState); + createTables(globalState, Randomly.fromOptions(4, 5, 6)); + prepareTables(globalState); + + extensionsList = globalState.getDbmsSpecificOptions().extensions; + if (!extensionsList.isEmpty()) { + String[] extensionNames = extensionsList.split(","); + + /* + * To avoid of a test interference with an extension objects, create them in a separate schema. Of course, + * they must be truly relocatable. + */ + globalState.executeStatement(new SQLQueryAdapter("CREATE SCHEMA extensions;", true)); + for (int i = 0; i < extensionNames.length; i++) { + globalState.executeStatement(new SQLQueryAdapter( + "CREATE EXTENSION " + extensionNames[i] + " WITH SCHEMA extensions;", true)); + } + } + } + + @Override + public SQLConnection createDatabase(MaterializeGlobalState globalState) throws SQLException { + if (globalState.getDbmsSpecificOptions().getTestOracleFactory().stream() + .anyMatch((o) -> o == MaterializeOracleFactory.PQS)) { + generateOnlyKnown = true; + } + + username = globalState.getOptions().getUserName(); + password = globalState.getOptions().getPassword(); + host = globalState.getOptions().getHost(); + port = globalState.getOptions().getPort(); + entryPath = "/test"; + entryURL = globalState.getDbmsSpecificOptions().connectionURL; + // trim URL to exclude "jdbc:" + if (entryURL.startsWith("jdbc:")) { + entryURL = entryURL.substring(5); + } + String entryDatabaseName = entryPath.substring(1); + databaseName = globalState.getDatabaseName(); + + try { + URI uri = new URI(entryURL); + String userInfoURI = uri.getUserInfo(); + String pathURI = uri.getPath(); + if (userInfoURI != null) { + // username and password specified in URL take precedence + if (userInfoURI.contains(":")) { + String[] userInfo = userInfoURI.split(":", 2); + username = userInfo[0]; + password = userInfo[1]; + } else { + username = userInfoURI; + password = null; + } + int userInfoIndex = entryURL.indexOf(userInfoURI); + String preUserInfo = entryURL.substring(0, userInfoIndex); + String postUserInfo = entryURL.substring(userInfoIndex + userInfoURI.length() + 1); + entryURL = preUserInfo + postUserInfo; + } + if (pathURI != null) { + entryPath = pathURI; + } + if (host == null) { + host = uri.getHost(); + } + if (port == MainOptions.NO_SET_PORT) { + port = uri.getPort(); + } + entryURL = String.format("%s://%s:%d/%s", uri.getScheme(), host, port, entryDatabaseName); + } catch (URISyntaxException e) { + throw new AssertionError(e); + } + Connection con = DriverManager.getConnection("jdbc:" + entryURL, username, password); + globalState.getState().logStatement(String.format("\\c %s;", entryDatabaseName)); + globalState.getState().logStatement("DROP DATABASE IF EXISTS " + databaseName); + createDatabaseCommand = getCreateDatabaseCommand(globalState); + globalState.getState().logStatement(createDatabaseCommand); + try (Statement s = con.createStatement()) { + s.execute("DROP DATABASE IF EXISTS " + databaseName); + } + try (Statement s = con.createStatement()) { + s.execute(createDatabaseCommand); + } + con.close(); + if (globalState.getDbmsSpecificOptions().setMaxTablesMVs) { + Connection conMzSystem = DriverManager.getConnection("jdbc:postgresql://localhost:6877/materialize", + "mz_system", "materialize"); + try (Statement s = conMzSystem.createStatement()) { + s.execute("ALTER SYSTEM SET max_tables TO 1000"); + } + try (Statement s = conMzSystem.createStatement()) { + s.execute("ALTER SYSTEM SET max_materialized_views TO 1000"); + } + conMzSystem.close(); + } + int databaseIndex = entryURL.indexOf(entryDatabaseName); + String preDatabaseName = entryURL.substring(0, databaseIndex); + String postDatabaseName = entryURL.substring(databaseIndex + entryDatabaseName.length()); + testURL = preDatabaseName + databaseName + postDatabaseName; + globalState.getState().logStatement(String.format("\\c %s;", databaseName)); + + con = DriverManager.getConnection("jdbc:" + testURL, username, password); + try (Statement s = con.createStatement()) { + // Serializable transaction isolation is much faster than Strict + // Serializable and should guarantee enough for SQLancer: + // https://materialize.com/docs/overview/isolation-level/ + s.execute("SET transaction_isolation = 'SERIALIZABLE'"); + // Make sure tables still are visible immediately by not using an + // index for them, see + // https://github.com/MaterializeInc/materialize/issues/19431 + s.execute("SET auto_route_introspection_queries = false"); + } + return new SQLConnection(con); + } + + protected void readFunctions(MaterializeGlobalState globalState) throws SQLException { + // ERROR: column "provolatile" does not exist + SQLQueryAdapter query = new SQLQueryAdapter("SELECT proname, 1 FROM pg_proc;"); + SQLancerResultSet rs = query.executeAndGet(globalState); + while (rs.next()) { + String functionName = rs.getString(1); + Character functionType = rs.getString(2).charAt(0); + globalState.addFunctionAndType(functionName, functionType); + } + } + + protected void createTables(MaterializeGlobalState globalState, int numTables) throws Exception { + int existingTables = globalState.getSchema().getDatabaseTables().size(); + int createdTables = 0; + int nextTableIndex = existingTables; + while (existingTables + createdTables < numTables) { + try { + String tableName = DBMSCommon.createTableName(nextTableIndex++); + SQLQueryAdapter createTable = MaterializeTableGenerator.generate(tableName, globalState.getSchema(), + generateOnlyKnown, globalState); + if (globalState.executeStatement(createTable)) { + createdTables++; + } + } catch (IgnoreMeException e) { + + } + } + } + + protected void prepareTables(MaterializeGlobalState globalState) throws Exception { + StatementExecutor se = new StatementExecutor<>(globalState, Action.values(), + MaterializeProvider::mapActions, (q) -> { + if (globalState.getSchema().getDatabaseTables().isEmpty()) { + throw new IgnoreMeException(); + } + }); + se.executeStatements(); + globalState.executeStatement(new SQLQueryAdapter("COMMIT", true)); + globalState.executeStatement(new SQLQueryAdapter("SET SESSION statement_timeout = 5000;\n")); + } + + private String getCreateDatabaseCommand(MaterializeGlobalState state) { + StringBuilder sb = new StringBuilder(); + sb.append("CREATE DATABASE " + databaseName + " "); + if (Randomly.getBoolean() && ((MaterializeOptions) state.getDbmsSpecificOptions()).testCollations) { + for (String lc : Arrays.asList("LC_COLLATE", "LC_CTYPE")) { + if (!state.getCollates().isEmpty() && Randomly.getBoolean()) { + sb.append(String.format(" %s = '%s'", lc, Randomly.fromList(state.getCollates()))); + } + } + } + return sb.toString(); + } + + @Override + public String getDBMSName() { + return "materialize"; + } + + @Override + public String getQueryPlan(String selectStr, MaterializeGlobalState globalState) throws Exception { + String queryPlan = ""; + String explainQuery = "EXPLAIN OPTIMIZED PLAN FOR " + selectStr; + if (globalState.getOptions().logEachSelect()) { + globalState.getLogger().writeCurrent(explainQuery); + try { + globalState.getLogger().getCurrentFileWriter().flush(); + } catch (IOException e) { + e.printStackTrace(); + } + } + SQLQueryAdapter q = new SQLQueryAdapter(explainQuery); + boolean afterProjection = false; // Remove the concrete expression after each Projection operator + SQLancerResultSet rs = q.executeAndGet(globalState); + if (rs != null) { + while (rs.next()) { + String line; + BufferedReader bufReader = new BufferedReader(new StringReader(rs.getString(1))); + while ((line = bufReader.readLine()) != null) { + String targetQueryPlan = line.trim() + ";"; // Unify format + if (targetQueryPlan.startsWith("Explained Query:")) { + continue; + } + if (afterProjection) { + afterProjection = false; + continue; + } + if (targetQueryPlan.startsWith("Project")) { + afterProjection = true; + } + // Remove all concrete expressions by keywords + if (targetQueryPlan.contains(">") || targetQueryPlan.contains("<") || targetQueryPlan.contains("=") + || targetQueryPlan.contains("*") || targetQueryPlan.contains("+") + || targetQueryPlan.contains("'")) { + continue; + } + queryPlan += targetQueryPlan; + } + } + } + + return queryPlan; + } + + @Override + protected double[] initializeWeightedAverageReward() { + return new double[Action.values().length]; + } + + @Override + protected void executeMutator(int index, MaterializeGlobalState globalState) throws Exception { + SQLQueryAdapter queryMutateTable = Action.values()[index].getQuery(globalState); + globalState.executeStatement(queryMutateTable); + } +} diff --git a/src/sqlancer/materialize/MaterializeSchema.java b/src/sqlancer/materialize/MaterializeSchema.java new file mode 100644 index 000000000..1762c5cb9 --- /dev/null +++ b/src/sqlancer/materialize/MaterializeSchema.java @@ -0,0 +1,330 @@ +package sqlancer.materialize; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.SQLIntegrityConstraintViolationException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.postgresql.util.PSQLException; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.common.DBMSCommon; +import sqlancer.common.schema.AbstractRelationalTable; +import sqlancer.common.schema.AbstractRowValue; +import sqlancer.common.schema.AbstractSchema; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.common.schema.AbstractTables; +import sqlancer.common.schema.TableIndex; +import sqlancer.materialize.MaterializeSchema.MaterializeTable; +import sqlancer.materialize.MaterializeSchema.MaterializeTable.TableType; +import sqlancer.materialize.ast.MaterializeConstant; + +public class MaterializeSchema extends AbstractSchema { + + private final String databaseName; + private final List indexNames; + + public List getIndexNames() { + return indexNames; + } + + public enum MaterializeDataType { + INT, BOOLEAN, TEXT, DECIMAL, FLOAT, REAL, BIT; + + public static MaterializeDataType getRandomType() { + List dataTypes = new ArrayList<>(Arrays.asList(values())); + if (MaterializeProvider.generateOnlyKnown) { + dataTypes.remove(MaterializeDataType.DECIMAL); + dataTypes.remove(MaterializeDataType.FLOAT); + dataTypes.remove(MaterializeDataType.REAL); + dataTypes.remove(MaterializeDataType.BIT); + } + return Randomly.fromList(dataTypes); + } + } + + public static class MaterializeColumn extends AbstractTableColumn { + + public MaterializeColumn(String name, MaterializeDataType columnType) { + super(name, null, columnType); + } + + public static MaterializeColumn createDummy(String name) { + return new MaterializeColumn(name, MaterializeDataType.INT); + } + + } + + public static class MaterializeTables extends AbstractTables { + + public MaterializeTables(List tables) { + super(tables); + } + + public MaterializeRowValue getRandomRowValue(SQLConnection con) throws SQLException { + String randomRow = String.format("SELECT %s FROM %s LIMIT 1", columnNamesAsString( + c -> c.getTable().getName() + "." + c.getName() + " AS " + c.getTable().getName() + c.getName()), + tableNamesAsString()); + Map values = new HashMap<>(); + try (Statement s = con.createStatement()) { + ResultSet randomRowValues = s.executeQuery(randomRow); + if (!randomRowValues.next()) { + throw new AssertionError("could not find random row! " + randomRow + "\n"); + } + for (int i = 0; i < getColumns().size(); i++) { + MaterializeColumn column = getColumns().get(i); + int columnIndex = randomRowValues.findColumn(column.getTable().getName() + column.getName()); + assert columnIndex == i + 1; + MaterializeConstant constant; + if (randomRowValues.getString(columnIndex) == null) { + constant = MaterializeConstant.createNullConstant(); + } else { + switch (column.getType()) { + case INT: + constant = MaterializeConstant.createIntConstant(randomRowValues.getLong(columnIndex)); + break; + case BOOLEAN: + constant = MaterializeConstant + .createBooleanConstant(randomRowValues.getBoolean(columnIndex)); + break; + case TEXT: + constant = MaterializeConstant.createTextConstant(randomRowValues.getString(columnIndex)); + break; + default: + throw new IgnoreMeException(); + } + } + values.put(column, constant); + } + assert !randomRowValues.next(); + return new MaterializeRowValue(this, values); + } catch (PSQLException e) { + throw new IgnoreMeException(); + } + + } + + } + + public static MaterializeDataType getColumnType(String typeString) { + switch (typeString) { + case "smallint": + case "integer": + case "bigint": + return MaterializeDataType.INT; + case "boolean": + return MaterializeDataType.BOOLEAN; + case "text": + case "character": + case "character varying": + case "name": + case "regclass": + return MaterializeDataType.TEXT; + case "numeric": + return MaterializeDataType.DECIMAL; + case "double precision": + return MaterializeDataType.FLOAT; + case "real": + return MaterializeDataType.REAL; + case "bit": + return MaterializeDataType.BIT; + default: + throw new AssertionError(typeString); + } + } + + public static class MaterializeRowValue + extends AbstractRowValue { + + protected MaterializeRowValue(MaterializeTables tables, Map values) { + super(tables, values); + } + + } + + public static class MaterializeTable + extends AbstractRelationalTable { + + public enum TableType { + STANDARD, TEMPORARY + } + + private final TableType tableType; + private final List statistics; + private final boolean isInsertable; + + public MaterializeTable(String tableName, List columns, List indexes, + TableType tableType, List statistics, boolean isView, + boolean isInsertable) { + super(tableName, columns, indexes, isView); + this.statistics = statistics; + this.isInsertable = isInsertable; + this.tableType = tableType; + } + + public List getStatistics() { + return statistics; + } + + public TableType getTableType() { + return tableType; + } + + public boolean isInsertable() { + return isInsertable; + } + + } + + public static final class MaterializeStatisticsObject { + private final String name; + + public MaterializeStatisticsObject(String name) { + this.name = name; + } + + public String getName() { + return name; + } + } + + public static final class MaterializeIndex extends TableIndex { + + private MaterializeIndex(String indexName) { + super(indexName); + } + + public static MaterializeIndex create(String indexName) { + return new MaterializeIndex(indexName); + } + + @Override + public String getIndexName() { + if (super.getIndexName().contentEquals("PRIMARY")) { + return "`PRIMARY`"; + } else { + return super.getIndexName(); + } + } + + } + + public static MaterializeSchema fromConnection(SQLConnection con, String databaseName) throws SQLException { + try { + List databaseTables = new ArrayList<>(); + List indexNames = new ArrayList<>(); + try (Statement s = con.createStatement()) { + // ERROR: column "is_insertable_into" does not exist + try (ResultSet rs = s.executeQuery( + "SELECT table_name, table_schema, table_type FROM information_schema.tables WHERE table_schema='public' OR table_schema LIKE 'pg_temp_%' ORDER BY table_name;")) { + while (rs.next()) { + String tableName = rs.getString("table_name"); + String tableTypeSchema = rs.getString("table_schema"); + boolean isInsertable = true; + String type = rs.getString("table_type"); + boolean isView = type.equals("VIEW") || type.equals("MATERIALIZED VIEW"); + if (isView) { + isInsertable = false; + } + MaterializeTable.TableType tableType = getTableType(tableTypeSchema); + List databaseColumns = getTableColumns(con, tableName); + List indexes = getIndexes(con, tableName); + List statistics = getStatistics(con); + MaterializeTable t = new MaterializeTable(tableName, databaseColumns, indexes, tableType, + statistics, isView, isInsertable); + for (MaterializeColumn c : databaseColumns) { + c.setTable(t); + } + databaseTables.add(t); + } + } + } + try (Statement s = con.createStatement()) { + try (ResultSet rs = s.executeQuery(String.format( + "SELECT mz_indexes.name, mz_databases.name FROM mz_indexes JOIN mz_relations ON mz_indexes.on_id = mz_relations.id JOIN mz_schemas ON mz_relations.schema_id = mz_schemas.id JOIN mz_databases ON mz_schemas.database_id = mz_databases.id WHERE mz_databases.name = '%s';", + databaseName))) { + while (rs.next()) { + String name = rs.getString(1); + indexNames.add(name); + } + } + } + return new MaterializeSchema(databaseTables, databaseName, indexNames); + } catch (SQLIntegrityConstraintViolationException e) { + throw new AssertionError(e); + } + } + + protected static List getStatistics(SQLConnection con) throws SQLException { + return new ArrayList<>(); + } + + protected static MaterializeTable.TableType getTableType(String tableTypeStr) throws AssertionError { + MaterializeTable.TableType tableType; + if (tableTypeStr.contentEquals("public")) { + tableType = TableType.STANDARD; + } else if (tableTypeStr.startsWith("pg_temp")) { + tableType = TableType.TEMPORARY; + } else { + throw new AssertionError(tableTypeStr); + } + return tableType; + } + + protected static List getIndexes(SQLConnection con, String tableName) throws SQLException { + List indexes = new ArrayList<>(); + try (Statement s = con.createStatement()) { + try (ResultSet rs = s.executeQuery(String + // org.postgresql.util.PSQLException: ERROR: unknown catalog item 'pg_indexes' + .format("SELECT c.relname as indexname FROM pg_catalog.pg_class c LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace LEFT JOIN pg_catalog.pg_index i ON i.indexrelid = c.oid LEFT JOIN pg_catalog.pg_class c2 ON i.indrelid = c2.oid WHERE c.relkind IN ('i','I','') AND n.nspname <> 'pg_catalog' AND n.nspname !~ '^pg_toast' AND n.nspname <> 'information_schema' AND c2.relname = '%s' AND pg_catalog.pg_table_is_visible(c.oid) ORDER BY indexname;", + tableName))) { + while (rs.next()) { + String indexName = rs.getString("indexname"); + if (DBMSCommon.matchesIndexName(indexName)) { + indexes.add(MaterializeIndex.create(indexName)); + } + } + } + } + return indexes; + } + + protected static List getTableColumns(SQLConnection con, String tableName) throws SQLException { + List columns = new ArrayList<>(); + try (Statement s = con.createStatement()) { + try (ResultSet rs = s + .executeQuery("select column_name, data_type from INFORMATION_SCHEMA.COLUMNS where table_name = '" + + tableName + "' ORDER BY column_name")) { + while (rs.next()) { + String columnName = rs.getString("column_name"); + String dataType = rs.getString("data_type"); + MaterializeColumn c = new MaterializeColumn(columnName, getColumnType(dataType)); + columns.add(c); + } + } + } + return columns; + } + + public MaterializeSchema(List databaseTables, String databaseName, List indexNames) { + super(databaseTables); + this.databaseName = databaseName; + this.indexNames = indexNames; + } + + public MaterializeTables getRandomTableNonEmptyTables() { + return new MaterializeTables(Randomly.nonEmptySubset(getDatabaseTables())); + } + + public String getDatabaseName() { + return databaseName; + } + +} diff --git a/src/sqlancer/materialize/MaterializeToStringVisitor.java b/src/sqlancer/materialize/MaterializeToStringVisitor.java new file mode 100644 index 000000000..285be276f --- /dev/null +++ b/src/sqlancer/materialize/MaterializeToStringVisitor.java @@ -0,0 +1,335 @@ +package sqlancer.materialize; + +import java.util.Optional; + +import sqlancer.Randomly; +import sqlancer.common.visitor.BinaryOperation; +import sqlancer.common.visitor.ToStringVisitor; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; +import sqlancer.materialize.ast.MaterializeAggregate; +import sqlancer.materialize.ast.MaterializeBetweenOperation; +import sqlancer.materialize.ast.MaterializeBinaryLogicalOperation; +import sqlancer.materialize.ast.MaterializeCastOperation; +import sqlancer.materialize.ast.MaterializeColumnValue; +import sqlancer.materialize.ast.MaterializeConstant; +import sqlancer.materialize.ast.MaterializeExpression; +import sqlancer.materialize.ast.MaterializeFunction; +import sqlancer.materialize.ast.MaterializeInOperation; +import sqlancer.materialize.ast.MaterializeJoin; +import sqlancer.materialize.ast.MaterializeJoin.MaterializeJoinType; +import sqlancer.materialize.ast.MaterializeLikeOperation; +import sqlancer.materialize.ast.MaterializeOrderByTerm; +import sqlancer.materialize.ast.MaterializePOSIXRegularExpression; +import sqlancer.materialize.ast.MaterializePostfixOperation; +import sqlancer.materialize.ast.MaterializePostfixText; +import sqlancer.materialize.ast.MaterializePrefixOperation; +import sqlancer.materialize.ast.MaterializeSelect; +import sqlancer.materialize.ast.MaterializeSelect.MaterializeFromTable; +import sqlancer.materialize.ast.MaterializeSelect.MaterializeSubquery; +import sqlancer.materialize.ast.MaterializeSimilarTo; + +public final class MaterializeToStringVisitor extends ToStringVisitor + implements MaterializeVisitor { + + @Override + public void visitSpecific(MaterializeExpression expr) { + MaterializeVisitor.super.visit(expr); + } + + @Override + public void visit(MaterializeConstant constant) { + sb.append(constant.getTextRepresentation()); + } + + @Override + public String get() { + return sb.toString(); + } + + @Override + public void visit(MaterializePostfixOperation op) { + sb.append("("); + visit(op.getExpression()); + sb.append(")"); + sb.append(" "); + sb.append(op.getOperatorTextRepresentation()); + } + + @Override + public void visit(MaterializeColumnValue c) { + sb.append(c.getColumn().getFullQualifiedName()); + } + + @Override + public void visit(MaterializePrefixOperation op) { + sb.append(op.getTextRepresentation()); + sb.append(" ("); + visit(op.getExpression()); + sb.append(")"); + } + + @Override + public void visit(MaterializeFromTable from) { + sb.append(from.getTable().getName()); + } + + @Override + public void visit(MaterializeSubquery subquery) { + sb.append("("); + visit(subquery.getSelect()); + sb.append(") AS "); + sb.append(subquery.getName()); + } + + @Override + public void visit(MaterializeSelect s) { + sb.append("SELECT "); + switch (s.getSelectOption()) { + case DISTINCT: + sb.append("DISTINCT "); + if (s.getDistinctOnClause() != null) { + sb.append("ON ("); + visit(s.getDistinctOnClause()); + sb.append(") "); + } + break; + case ALL: + sb.append(Randomly.fromOptions("ALL ", "")); + break; + default: + throw new AssertionError(); + } + visit(s.getFetchColumns()); + sb.append(" FROM "); + visit(s.getFromList()); + + for (MaterializeJoin j : s.getJoinClauses()) { + sb.append(" "); + switch (j.getType()) { + case INNER: + if (Randomly.getBoolean()) { + sb.append("INNER "); + } + sb.append("JOIN"); + break; + case LEFT: + sb.append("LEFT OUTER JOIN"); + break; + case RIGHT: + sb.append("RIGHT OUTER JOIN"); + break; + case FULL: + sb.append("FULL OUTER JOIN"); + break; + case CROSS: + sb.append("CROSS JOIN"); + break; + default: + throw new AssertionError(j.getType()); + } + sb.append(" "); + visit(j.getTableReference()); + if (j.getType() != MaterializeJoinType.CROSS) { + sb.append(" ON "); + visit(j.getOnClause()); + } + } + + if (s.getWhereClause() != null) { + sb.append(" WHERE "); + visit(s.getWhereClause()); + } + if (!s.getGroupByExpressions().isEmpty()) { + sb.append(" GROUP BY "); + visit(s.getGroupByExpressions()); + } + if (s.getHavingClause() != null) { + sb.append(" HAVING "); + visit(s.getHavingClause()); + + } + if (!s.getOrderByClauses().isEmpty()) { + sb.append(" ORDER BY "); + visit(s.getOrderByClauses()); + } + if (s.getLimitClause() != null) { + sb.append(" LIMIT "); + visit(s.getLimitClause()); + } + + if (s.getOffsetClause() != null) { + sb.append(" OFFSET "); + visit(s.getOffsetClause()); + } + } + + @Override + public void visit(MaterializeOrderByTerm op) { + visit(op.getExpr()); + sb.append(" "); + sb.append(op.getOrder()); + } + + @Override + public void visit(MaterializeFunction f) { + sb.append(f.getFunctionName()); + sb.append("("); + int i = 0; + for (MaterializeExpression arg : f.getArguments()) { + if (i++ != 0) { + sb.append(", "); + } + visit(arg); + } + sb.append(")"); + } + + @Override + public void visit(MaterializeCastOperation cast) { + if (cast.getCompoundType().getDataType() == MaterializeDataType.BOOLEAN) { + sb.append("("); + MaterializeExpression expr = cast.getExpression(); + visit(expr); + if (expr.getExpressionType() == MaterializeDataType.TEXT) { + sb.append(" != '')"); + } else if (expr.getExpressionType() == MaterializeDataType.BOOLEAN) { + sb.append(" != FALSE)"); + } else { + sb.append(" != 0)"); + } + } else if (Randomly.getBoolean()) { + if (cast.getCompoundType().getDataType() == MaterializeDataType.REAL + || cast.getCompoundType().getDataType() == MaterializeDataType.FLOAT) { + sb.append("CAST(CAST("); + visit(cast.getExpression()); + sb.append(" AS INT) AS "); + appendType(cast); + sb.append(")"); + } else { + sb.append("CAST("); + visit(cast.getExpression()); + sb.append(" AS "); + appendType(cast); + sb.append(")"); + } + } else { + if (cast.getCompoundType().getDataType() == MaterializeDataType.REAL + || cast.getCompoundType().getDataType() == MaterializeDataType.FLOAT) { + sb.append("("); + visit(cast.getExpression()); + sb.append(")::INT::"); + appendType(cast); + } else { + sb.append("("); + visit(cast.getExpression()); + sb.append(")::"); + appendType(cast); + } + } + } + + private void appendType(MaterializeCastOperation cast) { + MaterializeCompoundDataType compoundType = cast.getCompoundType(); + switch (compoundType.getDataType()) { + case BOOLEAN: + sb.append("BOOLEAN"); + break; + case INT: // TODO support also other int types + sb.append("INT"); + break; + case TEXT: + // TODO: append TEXT, CHAR + sb.append(Randomly.fromOptions("VARCHAR")); + break; + case REAL: + sb.append("FLOAT"); + break; + case DECIMAL: + sb.append("DECIMAL"); + break; + case FLOAT: + sb.append("REAL"); + break; + case BIT: + sb.append("INT"); + break; + default: + throw new AssertionError(cast.getType()); + } + Optional size = compoundType.getSize(); + if (size.isPresent()) { + sb.append("("); + sb.append(size.get()); + sb.append(")"); + } + } + + @Override + public void visit(MaterializeBetweenOperation op) { + sb.append("("); + visit(op.getExpr()); + sb.append(") BETWEEN "); + sb.append("("); + visit(op.getLeft()); + sb.append(") AND ("); + visit(op.getRight()); + sb.append(")"); + } + + @Override + public void visit(MaterializeInOperation op) { + sb.append("("); + visit(op.getExpr()); + sb.append(")"); + if (!op.isTrue()) { + sb.append(" NOT"); + } + sb.append(" IN ("); + visit(op.getListElements()); + sb.append(")"); + } + + @Override + public void visit(MaterializePostfixText op) { + visit(op.getExpr()); + sb.append(op.getText()); + } + + @Override + public void visit(MaterializeAggregate op) { + sb.append(op.getFunction()); + sb.append("("); + visit(op.getArgs()); + sb.append(")"); + } + + @Override + public void visit(MaterializeSimilarTo op) { + sb.append("("); + visit(op.getString()); + sb.append(" SIMILAR TO "); + visit(op.getSimilarTo()); + if (op.getEscapeCharacter() != null) { + visit(op.getEscapeCharacter()); + } + sb.append(")"); + } + + @Override + public void visit(MaterializePOSIXRegularExpression op) { + visit(op.getString()); + sb.append(op.getOp().getStringRepresentation()); + visit(op.getRegex()); + } + + @Override + public void visit(MaterializeBinaryLogicalOperation op) { + super.visit((BinaryOperation) op); + } + + @Override + public void visit(MaterializeLikeOperation op) { + super.visit((BinaryOperation) op); + } + +} diff --git a/src/sqlancer/materialize/MaterializeVisitor.java b/src/sqlancer/materialize/MaterializeVisitor.java new file mode 100644 index 000000000..40ed4e3f2 --- /dev/null +++ b/src/sqlancer/materialize/MaterializeVisitor.java @@ -0,0 +1,127 @@ +package sqlancer.materialize; + +import java.util.List; + +import sqlancer.materialize.MaterializeSchema.MaterializeColumn; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; +import sqlancer.materialize.ast.MaterializeAggregate; +import sqlancer.materialize.ast.MaterializeBetweenOperation; +import sqlancer.materialize.ast.MaterializeBinaryLogicalOperation; +import sqlancer.materialize.ast.MaterializeCastOperation; +import sqlancer.materialize.ast.MaterializeColumnValue; +import sqlancer.materialize.ast.MaterializeConstant; +import sqlancer.materialize.ast.MaterializeExpression; +import sqlancer.materialize.ast.MaterializeFunction; +import sqlancer.materialize.ast.MaterializeInOperation; +import sqlancer.materialize.ast.MaterializeLikeOperation; +import sqlancer.materialize.ast.MaterializeOrderByTerm; +import sqlancer.materialize.ast.MaterializePOSIXRegularExpression; +import sqlancer.materialize.ast.MaterializePostfixOperation; +import sqlancer.materialize.ast.MaterializePostfixText; +import sqlancer.materialize.ast.MaterializePrefixOperation; +import sqlancer.materialize.ast.MaterializeSelect; +import sqlancer.materialize.ast.MaterializeSelect.MaterializeFromTable; +import sqlancer.materialize.ast.MaterializeSelect.MaterializeSubquery; +import sqlancer.materialize.ast.MaterializeSimilarTo; +import sqlancer.materialize.gen.MaterializeExpressionGenerator; + +public interface MaterializeVisitor { + + void visit(MaterializeConstant constant); + + void visit(MaterializePostfixOperation op); + + void visit(MaterializeColumnValue c); + + void visit(MaterializePrefixOperation op); + + void visit(MaterializeSelect op); + + void visit(MaterializeOrderByTerm op); + + void visit(MaterializeFunction f); + + void visit(MaterializeCastOperation cast); + + void visit(MaterializeBetweenOperation op); + + void visit(MaterializeInOperation op); + + void visit(MaterializePostfixText op); + + void visit(MaterializeAggregate op); + + void visit(MaterializeSimilarTo op); + + void visit(MaterializePOSIXRegularExpression op); + + void visit(MaterializeFromTable from); + + void visit(MaterializeSubquery subquery); + + void visit(MaterializeBinaryLogicalOperation op); + + void visit(MaterializeLikeOperation op); + + default void visit(MaterializeExpression expression) { + if (expression instanceof MaterializeConstant) { + visit((MaterializeConstant) expression); + } else if (expression instanceof MaterializePostfixOperation) { + visit((MaterializePostfixOperation) expression); + } else if (expression instanceof MaterializeColumnValue) { + visit((MaterializeColumnValue) expression); + } else if (expression instanceof MaterializePrefixOperation) { + visit((MaterializePrefixOperation) expression); + } else if (expression instanceof MaterializeSelect) { + visit((MaterializeSelect) expression); + } else if (expression instanceof MaterializeOrderByTerm) { + visit((MaterializeOrderByTerm) expression); + } else if (expression instanceof MaterializeFunction) { + visit((MaterializeFunction) expression); + } else if (expression instanceof MaterializeCastOperation) { + visit((MaterializeCastOperation) expression); + } else if (expression instanceof MaterializeBetweenOperation) { + visit((MaterializeBetweenOperation) expression); + } else if (expression instanceof MaterializeInOperation) { + visit((MaterializeInOperation) expression); + } else if (expression instanceof MaterializeAggregate) { + visit((MaterializeAggregate) expression); + } else if (expression instanceof MaterializePostfixText) { + visit((MaterializePostfixText) expression); + } else if (expression instanceof MaterializeSimilarTo) { + visit((MaterializeSimilarTo) expression); + } else if (expression instanceof MaterializePOSIXRegularExpression) { + visit((MaterializePOSIXRegularExpression) expression); + } else if (expression instanceof MaterializeFromTable) { + visit((MaterializeFromTable) expression); + } else if (expression instanceof MaterializeSubquery) { + visit((MaterializeSubquery) expression); + } else if (expression instanceof MaterializeLikeOperation) { + visit((MaterializeLikeOperation) expression); + } else { + throw new AssertionError(expression); + } + } + + static String asString(MaterializeExpression expr) { + MaterializeToStringVisitor visitor = new MaterializeToStringVisitor(); + visitor.visit(expr); + return visitor.get(); + } + + static String asExpectedValues(MaterializeExpression expr) { + MaterializeExpectedValueVisitor v = new MaterializeExpectedValueVisitor(); + v.visit(expr); + return v.get(); + } + + static String getExpressionAsString(MaterializeGlobalState globalState, MaterializeDataType type, + List columns) { + MaterializeExpression expression = MaterializeExpressionGenerator.generateExpression(globalState, columns, + type); + MaterializeToStringVisitor visitor = new MaterializeToStringVisitor(); + visitor.visit(expression); + return visitor.get(); + } + +} diff --git a/src/sqlancer/materialize/ast/MaterializeAggregate.java b/src/sqlancer/materialize/ast/MaterializeAggregate.java new file mode 100644 index 000000000..e9dc638f3 --- /dev/null +++ b/src/sqlancer/materialize/ast/MaterializeAggregate.java @@ -0,0 +1,58 @@ +package sqlancer.materialize.ast; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.common.ast.FunctionNode; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; +import sqlancer.materialize.ast.MaterializeAggregate.MaterializeAggregateFunction; + +/** + * @see Built-in Aggregate Functions + */ +public class MaterializeAggregate extends FunctionNode + implements MaterializeExpression { + + public enum MaterializeAggregateFunction { + AVG(MaterializeDataType.INT, MaterializeDataType.FLOAT, MaterializeDataType.REAL, MaterializeDataType.DECIMAL), + BIT_AND(MaterializeDataType.INT), BIT_OR(MaterializeDataType.INT), BOOL_AND(MaterializeDataType.BOOLEAN), + BOOL_OR(MaterializeDataType.BOOLEAN), COUNT(MaterializeDataType.INT), MAX, MIN, + SUM(MaterializeDataType.INT, MaterializeDataType.FLOAT, MaterializeDataType.REAL, MaterializeDataType.DECIMAL); + + private MaterializeDataType[] supportedReturnTypes; + + MaterializeAggregateFunction(MaterializeDataType... supportedReturnTypes) { + this.supportedReturnTypes = supportedReturnTypes.clone(); + } + + public List getTypes(MaterializeDataType returnType) { + return Arrays.asList(returnType); + } + + public boolean supportsReturnType(MaterializeDataType returnType) { + return Arrays.asList(supportedReturnTypes).stream().anyMatch(t -> t == returnType) + || supportedReturnTypes.length == 0; + } + + public static List getAggregates(MaterializeDataType type) { + return Arrays.asList(values()).stream().filter(p -> p.supportsReturnType(type)) + .collect(Collectors.toList()); + } + + public MaterializeDataType getRandomReturnType() { + if (supportedReturnTypes.length == 0) { + return Randomly.fromOptions(MaterializeDataType.getRandomType()); + } else { + return Randomly.fromOptions(supportedReturnTypes); + } + } + + } + + public MaterializeAggregate(List args, MaterializeAggregateFunction func) { + super(func, args); + } + +} diff --git a/src/sqlancer/materialize/ast/MaterializeAlias.java b/src/sqlancer/materialize/ast/MaterializeAlias.java new file mode 100644 index 000000000..ec5445125 --- /dev/null +++ b/src/sqlancer/materialize/ast/MaterializeAlias.java @@ -0,0 +1,35 @@ +package sqlancer.materialize.ast; + +import sqlancer.common.visitor.UnaryOperation; + +public class MaterializeAlias implements UnaryOperation, MaterializeExpression { + + private final MaterializeExpression expr; + private final String alias; + + public MaterializeAlias(MaterializeExpression expr, String alias) { + this.expr = expr; + this.alias = alias; + } + + @Override + public MaterializeExpression getExpression() { + return expr; + } + + @Override + public String getOperatorRepresentation() { + return " as " + alias; + } + + @Override + public OperatorKind getOperatorKind() { + return OperatorKind.POSTFIX; + } + + @Override + public boolean omitBracketsWhenPrinting() { + return true; + } + +} diff --git a/src/sqlancer/materialize/ast/MaterializeBetweenOperation.java b/src/sqlancer/materialize/ast/MaterializeBetweenOperation.java new file mode 100644 index 000000000..dcb7b7cdb --- /dev/null +++ b/src/sqlancer/materialize/ast/MaterializeBetweenOperation.java @@ -0,0 +1,66 @@ +package sqlancer.materialize.ast; + +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; +import sqlancer.materialize.ast.MaterializeBinaryComparisonOperation.MaterializeBinaryComparisonOperator; +import sqlancer.materialize.ast.MaterializeBinaryLogicalOperation.BinaryLogicalOperator; + +public final class MaterializeBetweenOperation implements MaterializeExpression { + + private final MaterializeExpression expr; + private final MaterializeExpression left; + private final MaterializeExpression right; + private final boolean isSymmetric; + + public MaterializeBetweenOperation(MaterializeExpression expr, MaterializeExpression left, + MaterializeExpression right, boolean symmetric) { + this.expr = expr; + this.left = left; + this.right = right; + isSymmetric = symmetric; + } + + public MaterializeExpression getExpr() { + return expr; + } + + public MaterializeExpression getLeft() { + return left; + } + + public MaterializeExpression getRight() { + return right; + } + + public boolean isSymmetric() { + return isSymmetric; + } + + @Override + public MaterializeConstant getExpectedValue() { + MaterializeBinaryComparisonOperation leftComparison = new MaterializeBinaryComparisonOperation(left, expr, + MaterializeBinaryComparisonOperator.LESS_EQUALS); + MaterializeBinaryComparisonOperation rightComparison = new MaterializeBinaryComparisonOperation(expr, right, + MaterializeBinaryComparisonOperator.LESS_EQUALS); + MaterializeBinaryLogicalOperation andOperation = new MaterializeBinaryLogicalOperation(leftComparison, + rightComparison, MaterializeBinaryLogicalOperation.BinaryLogicalOperator.AND); + if (isSymmetric) { + MaterializeBinaryComparisonOperation leftComparison2 = new MaterializeBinaryComparisonOperation(right, expr, + MaterializeBinaryComparisonOperator.LESS_EQUALS); + MaterializeBinaryComparisonOperation rightComparison2 = new MaterializeBinaryComparisonOperation(expr, left, + MaterializeBinaryComparisonOperator.LESS_EQUALS); + MaterializeBinaryLogicalOperation andOperation2 = new MaterializeBinaryLogicalOperation(leftComparison2, + rightComparison2, MaterializeBinaryLogicalOperation.BinaryLogicalOperator.AND); + MaterializeBinaryLogicalOperation orOp = new MaterializeBinaryLogicalOperation(andOperation, andOperation2, + BinaryLogicalOperator.OR); + return orOp.getExpectedValue(); + } else { + return andOperation.getExpectedValue(); + } + } + + @Override + public MaterializeDataType getExpressionType() { + return MaterializeDataType.BOOLEAN; + } + +} diff --git a/src/sqlancer/materialize/ast/MaterializeBinaryArithmeticOperation.java b/src/sqlancer/materialize/ast/MaterializeBinaryArithmeticOperation.java new file mode 100644 index 000000000..0ba48924f --- /dev/null +++ b/src/sqlancer/materialize/ast/MaterializeBinaryArithmeticOperation.java @@ -0,0 +1,103 @@ +package sqlancer.materialize.ast; + +import java.util.function.BinaryOperator; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; +import sqlancer.materialize.ast.MaterializeBinaryArithmeticOperation.MaterializeBinaryOperator; + +public class MaterializeBinaryArithmeticOperation + extends BinaryOperatorNode implements MaterializeExpression { + + public enum MaterializeBinaryOperator implements Operator { + + ADDITION("+") { + @Override + public MaterializeConstant apply(MaterializeConstant left, MaterializeConstant right) { + return applyBitOperation(left, right, (l, r) -> l + r); + } + + }, + SUBTRACTION("-") { + @Override + public MaterializeConstant apply(MaterializeConstant left, MaterializeConstant right) { + return applyBitOperation(left, right, (l, r) -> l - r); + } + }, + MULTIPLICATION("*") { + @Override + public MaterializeConstant apply(MaterializeConstant left, MaterializeConstant right) { + return applyBitOperation(left, right, (l, r) -> l * r); + } + }, + DIVISION("/") { + + @Override + public MaterializeConstant apply(MaterializeConstant left, MaterializeConstant right) { + return applyBitOperation(left, right, (l, r) -> r == 0 ? -1 : l / r); + + } + + }, + MODULO("%") { + @Override + public MaterializeConstant apply(MaterializeConstant left, MaterializeConstant right) { + return applyBitOperation(left, right, (l, r) -> r == 0 ? -1 : l % r); + + } + }; + + private String textRepresentation; + + private static MaterializeConstant applyBitOperation(MaterializeConstant left, MaterializeConstant right, + BinaryOperator op) { + if (left.isNull() || right.isNull()) { + return MaterializeConstant.createNullConstant(); + } else { + long leftVal = left.cast(MaterializeDataType.INT).asInt(); + long rightVal = right.cast(MaterializeDataType.INT).asInt(); + long value = op.apply(leftVal, rightVal); + return MaterializeConstant.createIntConstant(value); + } + } + + MaterializeBinaryOperator(String textRepresentation) { + this.textRepresentation = textRepresentation; + } + + @Override + public String getTextRepresentation() { + return textRepresentation; + } + + public abstract MaterializeConstant apply(MaterializeConstant left, MaterializeConstant right); + + public static MaterializeBinaryOperator getRandom() { + return Randomly.fromOptions(values()); + } + + } + + public MaterializeBinaryArithmeticOperation(MaterializeExpression left, MaterializeExpression right, + MaterializeBinaryOperator op) { + super(left, right, op); + } + + @Override + public MaterializeConstant getExpectedValue() { + MaterializeConstant leftExpected = getLeft().getExpectedValue(); + MaterializeConstant rightExpected = getRight().getExpectedValue(); + if (leftExpected == null || rightExpected == null) { + return null; + } + return getOp().apply(leftExpected, rightExpected); + } + + @Override + public MaterializeDataType getExpressionType() { + return MaterializeDataType.INT; + } + +} diff --git a/src/sqlancer/materialize/ast/MaterializeBinaryBitOperation.java b/src/sqlancer/materialize/ast/MaterializeBinaryBitOperation.java new file mode 100644 index 000000000..8f143da76 --- /dev/null +++ b/src/sqlancer/materialize/ast/MaterializeBinaryBitOperation.java @@ -0,0 +1,46 @@ +package sqlancer.materialize.ast; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; +import sqlancer.materialize.ast.MaterializeBinaryBitOperation.MaterializeBinaryBitOperator; + +public class MaterializeBinaryBitOperation extends + BinaryOperatorNode implements MaterializeExpression { + + public enum MaterializeBinaryBitOperator implements Operator { + BITWISE_AND("&"), // + BITWISE_OR("|"), // + BITWISE_XOR("#"), // + BITWISE_SHIFT_LEFT("<<"), // + BITWISE_SHIFT_RIGHT(">>"); + + private String text; + + MaterializeBinaryBitOperator(String text) { + this.text = text; + } + + public static MaterializeBinaryBitOperator getRandom() { + return Randomly.fromOptions(MaterializeBinaryBitOperator.values()); + } + + @Override + public String getTextRepresentation() { + return text; + } + + } + + public MaterializeBinaryBitOperation(MaterializeBinaryBitOperator op, MaterializeExpression left, + MaterializeExpression right) { + super(left, right, op); + } + + @Override + public MaterializeDataType getExpressionType() { + return MaterializeDataType.INT; + } + +} diff --git a/src/sqlancer/materialize/ast/MaterializeBinaryComparisonOperation.java b/src/sqlancer/materialize/ast/MaterializeBinaryComparisonOperation.java new file mode 100644 index 000000000..25121935a --- /dev/null +++ b/src/sqlancer/materialize/ast/MaterializeBinaryComparisonOperation.java @@ -0,0 +1,121 @@ +package sqlancer.materialize.ast; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; +import sqlancer.materialize.ast.MaterializeBinaryComparisonOperation.MaterializeBinaryComparisonOperator; + +public class MaterializeBinaryComparisonOperation + extends BinaryOperatorNode + implements MaterializeExpression { + + public enum MaterializeBinaryComparisonOperator implements Operator { + EQUALS("=") { + @Override + public MaterializeConstant getExpectedValue(MaterializeConstant leftVal, MaterializeConstant rightVal) { + return leftVal.isEquals(rightVal); + } + }, + NOT_EQUALS("!=") { + @Override + public MaterializeConstant getExpectedValue(MaterializeConstant leftVal, MaterializeConstant rightVal) { + MaterializeConstant isEquals = leftVal.isEquals(rightVal); + if (isEquals.isBoolean()) { + return MaterializeConstant.createBooleanConstant(!isEquals.asBoolean()); + } + return isEquals; + } + }, + LESS("<") { + + @Override + public MaterializeConstant getExpectedValue(MaterializeConstant leftVal, MaterializeConstant rightVal) { + return leftVal.isLessThan(rightVal); + } + }, + LESS_EQUALS("<=") { + + @Override + public MaterializeConstant getExpectedValue(MaterializeConstant leftVal, MaterializeConstant rightVal) { + MaterializeConstant lessThan = leftVal.isLessThan(rightVal); + if (lessThan.isBoolean() && !lessThan.asBoolean()) { + return leftVal.isEquals(rightVal); + } else { + return lessThan; + } + } + }, + GREATER(">") { + @Override + public MaterializeConstant getExpectedValue(MaterializeConstant leftVal, MaterializeConstant rightVal) { + MaterializeConstant equals = leftVal.isEquals(rightVal); + if (equals.isBoolean() && equals.asBoolean()) { + return MaterializeConstant.createFalse(); + } else { + MaterializeConstant applyLess = leftVal.isLessThan(rightVal); + if (applyLess.isNull()) { + return MaterializeConstant.createNullConstant(); + } + return MaterializePrefixOperation.PrefixOperator.NOT.getExpectedValue(applyLess); + } + } + }, + GREATER_EQUALS(">=") { + + @Override + public MaterializeConstant getExpectedValue(MaterializeConstant leftVal, MaterializeConstant rightVal) { + MaterializeConstant equals = leftVal.isEquals(rightVal); + if (equals.isBoolean() && equals.asBoolean()) { + return MaterializeConstant.createTrue(); + } else { + MaterializeConstant applyLess = leftVal.isLessThan(rightVal); + if (applyLess.isNull()) { + return MaterializeConstant.createNullConstant(); + } + return MaterializePrefixOperation.PrefixOperator.NOT.getExpectedValue(applyLess); + } + } + + }; + + private final String textRepresentation; + + @Override + public String getTextRepresentation() { + return textRepresentation; + } + + MaterializeBinaryComparisonOperator(String textRepresentation) { + this.textRepresentation = textRepresentation; + } + + public abstract MaterializeConstant getExpectedValue(MaterializeConstant leftVal, MaterializeConstant rightVal); + + public static MaterializeBinaryComparisonOperator getRandom() { + return Randomly.fromOptions(MaterializeBinaryComparisonOperator.values()); + } + + } + + public MaterializeBinaryComparisonOperation(MaterializeExpression left, MaterializeExpression right, + MaterializeBinaryComparisonOperator op) { + super(left, right, op); + } + + @Override + public MaterializeConstant getExpectedValue() { + MaterializeConstant leftExpectedValue = getLeft().getExpectedValue(); + MaterializeConstant rightExpectedValue = getRight().getExpectedValue(); + if (leftExpectedValue == null || rightExpectedValue == null) { + return null; + } + return getOp().getExpectedValue(leftExpectedValue, rightExpectedValue); + } + + @Override + public MaterializeDataType getExpressionType() { + return MaterializeDataType.BOOLEAN; + } + +} diff --git a/src/sqlancer/materialize/ast/MaterializeBinaryLogicalOperation.java b/src/sqlancer/materialize/ast/MaterializeBinaryLogicalOperation.java new file mode 100644 index 000000000..8b8e457e2 --- /dev/null +++ b/src/sqlancer/materialize/ast/MaterializeBinaryLogicalOperation.java @@ -0,0 +1,89 @@ +package sqlancer.materialize.ast; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; +import sqlancer.materialize.ast.MaterializeBinaryLogicalOperation.BinaryLogicalOperator; + +public class MaterializeBinaryLogicalOperation extends BinaryOperatorNode + implements MaterializeExpression { + + public enum BinaryLogicalOperator implements Operator { + AND { + @Override + public MaterializeConstant apply(MaterializeConstant left, MaterializeConstant right) { + MaterializeConstant leftBool = left.cast(MaterializeDataType.BOOLEAN); + MaterializeConstant rightBool = right.cast(MaterializeDataType.BOOLEAN); + if (leftBool.isNull()) { + if (rightBool.isNull()) { + return MaterializeConstant.createNullConstant(); + } else { + if (rightBool.asBoolean()) { + return MaterializeConstant.createNullConstant(); + } else { + return MaterializeConstant.createFalse(); + } + } + } else if (!leftBool.asBoolean()) { + return MaterializeConstant.createFalse(); + } + assert leftBool.asBoolean(); + if (rightBool.isNull()) { + return MaterializeConstant.createNullConstant(); + } else { + return MaterializeConstant.createBooleanConstant(rightBool.isBoolean() && rightBool.asBoolean()); + } + } + }, + OR { + @Override + public MaterializeConstant apply(MaterializeConstant left, MaterializeConstant right) { + MaterializeConstant leftBool = left.cast(MaterializeDataType.BOOLEAN); + MaterializeConstant rightBool = right.cast(MaterializeDataType.BOOLEAN); + if (leftBool.isBoolean() && leftBool.asBoolean()) { + return MaterializeConstant.createTrue(); + } + if (rightBool.isBoolean() && rightBool.asBoolean()) { + return MaterializeConstant.createTrue(); + } + if (leftBool.isNull() || rightBool.isNull()) { + return MaterializeConstant.createNullConstant(); + } + return MaterializeConstant.createFalse(); + } + }; + + public abstract MaterializeConstant apply(MaterializeConstant left, MaterializeConstant right); + + public static BinaryLogicalOperator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return toString(); + } + } + + public MaterializeBinaryLogicalOperation(MaterializeExpression left, MaterializeExpression right, + BinaryLogicalOperator op) { + super(left, right, op); + } + + @Override + public MaterializeDataType getExpressionType() { + return MaterializeDataType.BOOLEAN; + } + + @Override + public MaterializeConstant getExpectedValue() { + MaterializeConstant leftExpectedValue = getLeft().getExpectedValue(); + MaterializeConstant rightExpectedValue = getRight().getExpectedValue(); + if (leftExpectedValue == null || rightExpectedValue == null) { + return null; + } + return getOp().apply(leftExpectedValue, rightExpectedValue); + } + +} diff --git a/src/sqlancer/materialize/ast/MaterializeBinaryRangeOperation.java b/src/sqlancer/materialize/ast/MaterializeBinaryRangeOperation.java new file mode 100644 index 000000000..cf1ef66ef --- /dev/null +++ b/src/sqlancer/materialize/ast/MaterializeBinaryRangeOperation.java @@ -0,0 +1,74 @@ +package sqlancer.materialize.ast; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryNode; +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; + +public class MaterializeBinaryRangeOperation extends BinaryNode + implements MaterializeExpression { + + private final String op; + + public enum MaterializeBinaryRangeOperator implements Operator { + UNION("+"), INTERSECTION("*"), DIFFERENCE("-"); + + private final String textRepresentation; + + MaterializeBinaryRangeOperator(String textRepresentation) { + this.textRepresentation = textRepresentation; + } + + @Override + public String getTextRepresentation() { + return textRepresentation; + } + + public static MaterializeBinaryRangeOperator getRandom() { + return Randomly.fromOptions(values()); + } + + } + + public enum MaterializeBinaryRangeComparisonOperator { + CONTAINS_RANGE_OR_ELEMENT("@>"), RANGE_OR_ELEMENT_IS_CONTAINED("<@"), OVERLAP("&&"), STRICT_LEFT_OF("<<"), + STRICT_RIGHT_OF(">>"); + + private final String textRepresentation; + + MaterializeBinaryRangeComparisonOperator(String textRepresentation) { + this.textRepresentation = textRepresentation; + } + + public String getTextRepresentation() { + return textRepresentation; + } + + public static MaterializeBinaryRangeComparisonOperator getRandom() { + return Randomly.fromOptions(values()); + } + } + + public MaterializeBinaryRangeOperation(MaterializeBinaryRangeComparisonOperator op, MaterializeExpression left, + MaterializeExpression right) { + super(left, right); + this.op = op.getTextRepresentation(); + } + + public MaterializeBinaryRangeOperation(MaterializeBinaryRangeOperator op, MaterializeExpression left, + MaterializeExpression right) { + super(left, right); + this.op = op.getTextRepresentation(); + } + + @Override + public MaterializeDataType getExpressionType() { + return MaterializeDataType.BOOLEAN; + } + + @Override + public String getOperatorRepresentation() { + return op; + } + +} diff --git a/src/sqlancer/materialize/ast/MaterializeCastOperation.java b/src/sqlancer/materialize/ast/MaterializeCastOperation.java new file mode 100644 index 000000000..ff7e20180 --- /dev/null +++ b/src/sqlancer/materialize/ast/MaterializeCastOperation.java @@ -0,0 +1,45 @@ +package sqlancer.materialize.ast; + +import sqlancer.materialize.MaterializeCompoundDataType; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; + +public class MaterializeCastOperation implements MaterializeExpression { + + private final MaterializeExpression expression; + private final MaterializeCompoundDataType type; + + public MaterializeCastOperation(MaterializeExpression expression, MaterializeCompoundDataType type) { + if (expression == null) { + throw new AssertionError(); + } + this.expression = expression; + this.type = type; + } + + @Override + public MaterializeDataType getExpressionType() { + return type.getDataType(); + } + + @Override + public MaterializeConstant getExpectedValue() { + MaterializeConstant expectedValue = expression.getExpectedValue(); + if (expectedValue == null) { + return null; + } + return expectedValue.cast(type.getDataType()); + } + + public MaterializeExpression getExpression() { + return expression; + } + + public MaterializeDataType getType() { + return type.getDataType(); + } + + public MaterializeCompoundDataType getCompoundType() { + return type; + } + +} diff --git a/src/sqlancer/materialize/ast/MaterializeColumnValue.java b/src/sqlancer/materialize/ast/MaterializeColumnValue.java new file mode 100644 index 000000000..054f34ffd --- /dev/null +++ b/src/sqlancer/materialize/ast/MaterializeColumnValue.java @@ -0,0 +1,34 @@ +package sqlancer.materialize.ast; + +import sqlancer.materialize.MaterializeSchema.MaterializeColumn; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; + +public class MaterializeColumnValue implements MaterializeExpression { + + private final MaterializeColumn c; + private final MaterializeConstant expectedValue; + + public MaterializeColumnValue(MaterializeColumn c, MaterializeConstant expectedValue) { + this.c = c; + this.expectedValue = expectedValue; + } + + @Override + public MaterializeDataType getExpressionType() { + return c.getType(); + } + + @Override + public MaterializeConstant getExpectedValue() { + return expectedValue; + } + + public static MaterializeColumnValue create(MaterializeColumn c, MaterializeConstant expected) { + return new MaterializeColumnValue(c, expected); + } + + public MaterializeColumn getColumn() { + return c; + } + +} diff --git a/src/sqlancer/materialize/ast/MaterializeConcatOperation.java b/src/sqlancer/materialize/ast/MaterializeConcatOperation.java new file mode 100644 index 000000000..c1963c871 --- /dev/null +++ b/src/sqlancer/materialize/ast/MaterializeConcatOperation.java @@ -0,0 +1,37 @@ +package sqlancer.materialize.ast; + +import sqlancer.common.ast.BinaryNode; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; + +public class MaterializeConcatOperation extends BinaryNode implements MaterializeExpression { + + public MaterializeConcatOperation(MaterializeExpression left, MaterializeExpression right) { + super(left, right); + } + + @Override + public MaterializeDataType getExpressionType() { + return MaterializeDataType.TEXT; + } + + @Override + public MaterializeConstant getExpectedValue() { + MaterializeConstant leftExpectedValue = getLeft().getExpectedValue(); + MaterializeConstant rightExpectedValue = getRight().getExpectedValue(); + if (leftExpectedValue == null || rightExpectedValue == null) { + return null; + } + if (leftExpectedValue.isNull() || rightExpectedValue.isNull()) { + return MaterializeConstant.createNullConstant(); + } + String leftStr = leftExpectedValue.cast(MaterializeDataType.TEXT).getUnquotedTextRepresentation(); + String rightStr = rightExpectedValue.cast(MaterializeDataType.TEXT).getUnquotedTextRepresentation(); + return MaterializeConstant.createTextConstant(leftStr + rightStr); + } + + @Override + public String getOperatorRepresentation() { + return "||"; + } + +} diff --git a/src/sqlancer/materialize/ast/MaterializeConstant.java b/src/sqlancer/materialize/ast/MaterializeConstant.java new file mode 100644 index 000000000..dbf668806 --- /dev/null +++ b/src/sqlancer/materialize/ast/MaterializeConstant.java @@ -0,0 +1,516 @@ +package sqlancer.materialize.ast; + +import java.math.BigDecimal; + +import sqlancer.IgnoreMeException; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; + +public abstract class MaterializeConstant implements MaterializeExpression { + + public abstract String getTextRepresentation(); + + public abstract String getUnquotedTextRepresentation(); + + public static class BooleanConstant extends MaterializeConstant { + + private final boolean value; + + public BooleanConstant(boolean value) { + this.value = value; + } + + @Override + public String getTextRepresentation() { + return value ? "TRUE" : "FALSE"; + } + + @Override + public MaterializeDataType getExpressionType() { + return MaterializeDataType.BOOLEAN; + } + + @Override + public boolean asBoolean() { + return value; + } + + @Override + public boolean isBoolean() { + return true; + } + + @Override + public MaterializeConstant isEquals(MaterializeConstant rightVal) { + if (rightVal.isNull()) { + return MaterializeConstant.createNullConstant(); + } else if (rightVal.isBoolean()) { + return MaterializeConstant.createBooleanConstant(value == rightVal.asBoolean()); + } else if (rightVal.isString()) { + return MaterializeConstant + .createBooleanConstant(value == rightVal.cast(MaterializeDataType.BOOLEAN).asBoolean()); + } else { + throw new AssertionError(rightVal); + } + } + + @Override + protected MaterializeConstant isLessThan(MaterializeConstant rightVal) { + if (rightVal.isNull()) { + return MaterializeConstant.createNullConstant(); + } else if (rightVal.isString()) { + return isLessThan(rightVal.cast(MaterializeDataType.BOOLEAN)); + } else { + assert rightVal.isBoolean(); + return MaterializeConstant.createBooleanConstant((value ? 1 : 0) < (rightVal.asBoolean() ? 1 : 0)); + } + } + + @Override + public MaterializeConstant cast(MaterializeDataType type) { + switch (type) { + case BOOLEAN: + return this; + case INT: + return MaterializeConstant.createIntConstant(value ? 1 : 0); + case TEXT: + return MaterializeConstant.createTextConstant(value ? "true" : "false"); + default: + return null; + } + } + + @Override + public String getUnquotedTextRepresentation() { + return getTextRepresentation(); + } + + } + + public static class MaterializeNullConstant extends MaterializeConstant { + + @Override + public String getTextRepresentation() { + return "NULL"; + } + + @Override + public MaterializeDataType getExpressionType() { + return null; + } + + @Override + public boolean isNull() { + return true; + } + + @Override + public MaterializeConstant isEquals(MaterializeConstant rightVal) { + return MaterializeConstant.createNullConstant(); + } + + @Override + protected MaterializeConstant isLessThan(MaterializeConstant rightVal) { + return MaterializeConstant.createNullConstant(); + } + + @Override + public MaterializeConstant cast(MaterializeDataType type) { + return MaterializeConstant.createNullConstant(); + } + + @Override + public String getUnquotedTextRepresentation() { + return getTextRepresentation(); + } + + } + + public static class StringConstant extends MaterializeConstant { + + private final String value; + + public StringConstant(String value) { + this.value = value; + } + + @Override + public String getTextRepresentation() { + return String.format("'%s'", value.replace("'", "''")); + } + + @Override + public MaterializeConstant isEquals(MaterializeConstant rightVal) { + if (rightVal.isNull()) { + return MaterializeConstant.createNullConstant(); + } else if (rightVal.isInt()) { + return cast(MaterializeDataType.INT).isEquals(rightVal.cast(MaterializeDataType.INT)); + } else if (rightVal.isBoolean()) { + return cast(MaterializeDataType.BOOLEAN).isEquals(rightVal.cast(MaterializeDataType.BOOLEAN)); + } else if (rightVal.isString()) { + return MaterializeConstant.createBooleanConstant(value.contentEquals(rightVal.asString())); + } else { + throw new AssertionError(rightVal); + } + } + + @Override + protected MaterializeConstant isLessThan(MaterializeConstant rightVal) { + if (rightVal.isNull()) { + return MaterializeConstant.createNullConstant(); + } else if (rightVal.isInt()) { + return cast(MaterializeDataType.INT).isLessThan(rightVal.cast(MaterializeDataType.INT)); + } else if (rightVal.isBoolean()) { + return cast(MaterializeDataType.BOOLEAN).isLessThan(rightVal.cast(MaterializeDataType.BOOLEAN)); + } else if (rightVal.isString()) { + return MaterializeConstant.createBooleanConstant(value.compareTo(rightVal.asString()) < 0); + } else { + throw new AssertionError(rightVal); + } + } + + @Override + public MaterializeConstant cast(MaterializeDataType type) { + if (type == MaterializeDataType.TEXT) { + return this; + } + String s = value.trim(); + switch (type) { + case BOOLEAN: + try { + return MaterializeConstant.createBooleanConstant(Long.parseLong(s) != 0); + } catch (NumberFormatException e) { + } + switch (s.toUpperCase()) { + case "T": + case "TR": + case "TRU": + case "TRUE": + case "1": + case "YES": + case "YE": + case "Y": + case "ON": + return MaterializeConstant.createTrue(); + case "F": + case "FA": + case "FAL": + case "FALS": + case "FALSE": + case "N": + case "NO": + case "OF": + case "OFF": + default: + return MaterializeConstant.createFalse(); + } + case INT: + try { + return MaterializeConstant.createIntConstant(Long.parseLong(s)); + } catch (NumberFormatException e) { + return MaterializeConstant.createIntConstant(-1); + } + case TEXT: + return this; + default: + return null; + } + } + + @Override + public MaterializeDataType getExpressionType() { + return MaterializeDataType.TEXT; + } + + @Override + public boolean isString() { + return true; + } + + @Override + public String asString() { + return value; + } + + @Override + public String getUnquotedTextRepresentation() { + return value; + } + + } + + public static class IntConstant extends MaterializeConstant { + + private final long val; + + public IntConstant(long val) { + this.val = val; + } + + @Override + public String getTextRepresentation() { + return String.valueOf(val); + } + + @Override + public MaterializeDataType getExpressionType() { + return MaterializeDataType.INT; + } + + @Override + public long asInt() { + return val; + } + + @Override + public boolean isInt() { + return true; + } + + @Override + public MaterializeConstant isEquals(MaterializeConstant rightVal) { + if (rightVal.isNull()) { + return MaterializeConstant.createNullConstant(); + } else if (rightVal.isBoolean()) { + return cast(MaterializeDataType.BOOLEAN).isEquals(rightVal); + } else if (rightVal.isInt()) { + return MaterializeConstant.createBooleanConstant(val == rightVal.asInt()); + } else if (rightVal.isString()) { + return MaterializeConstant.createBooleanConstant(val == rightVal.cast(MaterializeDataType.INT).asInt()); + } else { + throw new AssertionError(rightVal); + } + } + + @Override + protected MaterializeConstant isLessThan(MaterializeConstant rightVal) { + if (rightVal.isNull()) { + return MaterializeConstant.createNullConstant(); + } else if (rightVal.isInt()) { + return MaterializeConstant.createBooleanConstant(val < rightVal.asInt()); + } else if (rightVal.isBoolean()) { + throw new AssertionError(rightVal); + } else if (rightVal.isString()) { + return MaterializeConstant.createBooleanConstant(val < rightVal.cast(MaterializeDataType.INT).asInt()); + } else { + throw new IgnoreMeException(); + } + + } + + @Override + public MaterializeConstant cast(MaterializeDataType type) { + switch (type) { + case BOOLEAN: + return MaterializeConstant.createBooleanConstant(val != 0); + case INT: + return this; + case TEXT: + return MaterializeConstant.createTextConstant(String.valueOf(val)); + default: + return null; + } + } + + @Override + public String getUnquotedTextRepresentation() { + return getTextRepresentation(); + } + + } + + public static MaterializeConstant createNullConstant() { + return new MaterializeNullConstant(); + } + + public String asString() { + throw new UnsupportedOperationException(this.toString()); + } + + public boolean isString() { + return false; + } + + public static MaterializeConstant createIntConstant(long val) { + return new IntConstant(val); + } + + public static MaterializeConstant createBooleanConstant(boolean val) { + return new BooleanConstant(val); + } + + @Override + public MaterializeConstant getExpectedValue() { + return this; + } + + public boolean isNull() { + return false; + } + + public boolean asBoolean() { + throw new UnsupportedOperationException(this.toString()); + } + + public static MaterializeConstant createFalse() { + return createBooleanConstant(false); + } + + public static MaterializeConstant createTrue() { + return createBooleanConstant(true); + } + + public long asInt() { + throw new UnsupportedOperationException(this.toString()); + } + + public boolean isBoolean() { + return false; + } + + public abstract MaterializeConstant isEquals(MaterializeConstant rightVal); + + public boolean isInt() { + return false; + } + + protected abstract MaterializeConstant isLessThan(MaterializeConstant rightVal); + + @Override + public String toString() { + return getTextRepresentation(); + } + + public abstract MaterializeConstant cast(MaterializeDataType type); + + public static MaterializeConstant createTextConstant(String string) { + return new StringConstant(string); + } + + public abstract static class MaterializeConstantBase extends MaterializeConstant { + + @Override + public String getUnquotedTextRepresentation() { + return null; + } + + @Override + public MaterializeConstant isEquals(MaterializeConstant rightVal) { + return null; + } + + @Override + protected MaterializeConstant isLessThan(MaterializeConstant rightVal) { + return null; + } + + @Override + public MaterializeConstant cast(MaterializeDataType type) { + return null; + } + } + + public static class DecimalConstant extends MaterializeConstantBase { + + private final BigDecimal val; + + public DecimalConstant(BigDecimal val) { + this.val = val; + } + + @Override + public String getTextRepresentation() { + return String.valueOf(val); + } + + @Override + public MaterializeDataType getExpressionType() { + return MaterializeDataType.DECIMAL; + } + + } + + public static class FloatConstant extends MaterializeConstantBase { + + private final float val; + + public FloatConstant(float val) { + this.val = val; + } + + @Override + public String getTextRepresentation() { + if (Double.isFinite(val)) { + return String.valueOf(val); + } else { + return "'" + val + "'"; + } + } + + @Override + public MaterializeDataType getExpressionType() { + return MaterializeDataType.FLOAT; + } + + } + + public static class DoubleConstant extends MaterializeConstantBase { + + private final double val; + + public DoubleConstant(double val) { + this.val = val; + } + + @Override + public String getTextRepresentation() { + if (Double.isFinite(val)) { + return String.valueOf(val); + } else { + return "'" + val + "'"; + } + } + + @Override + public MaterializeDataType getExpressionType() { + return MaterializeDataType.FLOAT; + } + + } + + public static class BitConstant extends MaterializeConstantBase { + + private final long val; + + public BitConstant(long val) { + this.val = val; + } + + @Override + public String getTextRepresentation() { + return String.format("%d", val); + } + + @Override + public MaterializeDataType getExpressionType() { + return MaterializeDataType.BIT; + } + + } + + public static MaterializeConstant createDecimalConstant(BigDecimal bigDecimal) { + return new DecimalConstant(bigDecimal); + } + + public static MaterializeConstant createFloatConstant(float val) { + return new FloatConstant(val); + } + + public static MaterializeConstant createDoubleConstant(double val) { + return new DoubleConstant(val); + } + + public static MaterializeExpression createBitConstant(long integer) { + return new BitConstant(integer); + } + +} diff --git a/src/sqlancer/materialize/ast/MaterializeExpression.java b/src/sqlancer/materialize/ast/MaterializeExpression.java new file mode 100644 index 000000000..70444d8ee --- /dev/null +++ b/src/sqlancer/materialize/ast/MaterializeExpression.java @@ -0,0 +1,16 @@ +package sqlancer.materialize.ast; + +import sqlancer.common.ast.newast.Expression; +import sqlancer.materialize.MaterializeSchema.MaterializeColumn; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; + +public interface MaterializeExpression extends Expression { + + default MaterializeDataType getExpressionType() { + return null; + } + + default MaterializeConstant getExpectedValue() { + return null; + } +} diff --git a/src/sqlancer/materialize/ast/MaterializeFunction.java b/src/sqlancer/materialize/ast/MaterializeFunction.java new file mode 100644 index 000000000..24fdce866 --- /dev/null +++ b/src/sqlancer/materialize/ast/MaterializeFunction.java @@ -0,0 +1,256 @@ +package sqlancer.materialize.ast; + +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; + +public class MaterializeFunction implements MaterializeExpression { + + private final String func; + private final MaterializeExpression[] args; + private final MaterializeDataType returnType; + private MaterializeFunctionWithResult functionWithKnownResult; + + public MaterializeFunction(MaterializeFunctionWithResult func, MaterializeDataType returnType, + MaterializeExpression... args) { + functionWithKnownResult = func; + this.func = func.getName(); + this.returnType = returnType; + this.args = args.clone(); + } + + public MaterializeFunction(MaterializeFunctionWithUnknownResult f, MaterializeDataType returnType, + MaterializeExpression... args) { + this.func = f.getName(); + this.returnType = returnType; + this.args = args.clone(); + } + + public String getFunctionName() { + return func; + } + + public MaterializeExpression[] getArguments() { + return args.clone(); + } + + public enum MaterializeFunctionWithResult { + ABS(1, "abs") { + + @Override + public MaterializeConstant apply(MaterializeConstant[] evaluatedArgs, MaterializeExpression... args) { + if (evaluatedArgs[0].isNull()) { + return MaterializeConstant.createNullConstant(); + } else { + return MaterializeConstant + .createIntConstant(Math.abs(evaluatedArgs[0].cast(MaterializeDataType.INT).asInt())); + } + } + + @Override + public boolean supportsReturnType(MaterializeDataType type) { + return type == MaterializeDataType.INT; + } + + @Override + public MaterializeDataType[] getInputTypesForReturnType(MaterializeDataType returnType, int nrArguments) { + return new MaterializeDataType[] { returnType }; + } + + }, + LOWER(1, "lower") { + + @Override + public MaterializeConstant apply(MaterializeConstant[] evaluatedArgs, MaterializeExpression... args) { + if (evaluatedArgs[0].isNull()) { + return MaterializeConstant.createNullConstant(); + } else { + String text = evaluatedArgs[0].asString(); + return MaterializeConstant.createTextConstant(text.toLowerCase()); + } + } + + @Override + public boolean supportsReturnType(MaterializeDataType type) { + return type == MaterializeDataType.TEXT; + } + + @Override + public MaterializeDataType[] getInputTypesForReturnType(MaterializeDataType returnType, int nrArguments) { + return new MaterializeDataType[] { MaterializeDataType.TEXT }; + } + + }, + LENGTH(1, "length") { + @Override + public MaterializeConstant apply(MaterializeConstant[] evaluatedArgs, MaterializeExpression... args) { + if (evaluatedArgs[0].isNull()) { + return MaterializeConstant.createNullConstant(); + } + String text = evaluatedArgs[0].asString(); + return MaterializeConstant.createIntConstant(text.length()); + } + + @Override + public boolean supportsReturnType(MaterializeDataType type) { + return type == MaterializeDataType.INT; + } + + @Override + public MaterializeDataType[] getInputTypesForReturnType(MaterializeDataType returnType, int nrArguments) { + return new MaterializeDataType[] { MaterializeDataType.TEXT }; + } + }, + UPPER(1, "upper") { + + @Override + public MaterializeConstant apply(MaterializeConstant[] evaluatedArgs, MaterializeExpression... args) { + if (evaluatedArgs[0].isNull()) { + return MaterializeConstant.createNullConstant(); + } else { + String text = evaluatedArgs[0].asString(); + return MaterializeConstant.createTextConstant(text.toUpperCase()); + } + } + + @Override + public boolean supportsReturnType(MaterializeDataType type) { + return type == MaterializeDataType.TEXT; + } + + @Override + public MaterializeDataType[] getInputTypesForReturnType(MaterializeDataType returnType, int nrArguments) { + return new MaterializeDataType[] { MaterializeDataType.TEXT }; + } + + }, + NUM_NONNULLS(1, "num_nonnulls") { + @Override + public MaterializeConstant apply(MaterializeConstant[] args, MaterializeExpression... origArgs) { + int nr = 0; + for (MaterializeConstant c : args) { + if (!c.isNull()) { + nr++; + } + } + return MaterializeConstant.createIntConstant(nr); + } + + @Override + public MaterializeDataType[] getInputTypesForReturnType(MaterializeDataType returnType, int nrArguments) { + return getRandomTypes(nrArguments); + } + + @Override + public boolean supportsReturnType(MaterializeDataType type) { + return type == MaterializeDataType.INT; + } + + @Override + public boolean isVariadic() { + return true; + } + + }, + NUM_NULLS(1, "num_nulls") { + @Override + public MaterializeConstant apply(MaterializeConstant[] args, MaterializeExpression... origArgs) { + int nr = 0; + for (MaterializeConstant c : args) { + if (c.isNull()) { + nr++; + } + } + return MaterializeConstant.createIntConstant(nr); + } + + @Override + public MaterializeDataType[] getInputTypesForReturnType(MaterializeDataType returnType, int nrArguments) { + return getRandomTypes(nrArguments); + } + + @Override + public boolean supportsReturnType(MaterializeDataType type) { + return type == MaterializeDataType.INT; + } + + @Override + public boolean isVariadic() { + return true; + } + + }; + + private String functionName; + final int nrArgs; + private final boolean variadic; + + public MaterializeDataType[] getRandomTypes(int nr) { + MaterializeDataType[] types = new MaterializeDataType[nr]; + for (int i = 0; i < types.length; i++) { + types[i] = MaterializeDataType.getRandomType(); + } + return types; + } + + MaterializeFunctionWithResult(int nrArgs, String functionName) { + this.nrArgs = nrArgs; + this.functionName = functionName; + this.variadic = false; + } + + /** + * Gets the number of arguments if the function is non-variadic. If the function is variadic, the minimum number + * of arguments is returned. + * + * @return the number of arguments + */ + public int getNrArgs() { + return nrArgs; + } + + public abstract MaterializeConstant apply(MaterializeConstant[] evaluatedArgs, MaterializeExpression... args); + + @Override + public String toString() { + return functionName; + } + + public boolean isVariadic() { + return variadic; + } + + public String getName() { + return functionName; + } + + public abstract boolean supportsReturnType(MaterializeDataType type); + + public abstract MaterializeDataType[] getInputTypesForReturnType(MaterializeDataType returnType, + int nrArguments); + + public boolean checkArguments(MaterializeExpression... constants) { + return true; + } + + } + + @Override + public MaterializeConstant getExpectedValue() { + if (functionWithKnownResult == null) { + return null; + } + MaterializeConstant[] constants = new MaterializeConstant[args.length]; + for (int i = 0; i < constants.length; i++) { + constants[i] = args[i].getExpectedValue(); + if (constants[i] == null) { + return null; + } + } + return functionWithKnownResult.apply(constants, args); + } + + @Override + public MaterializeDataType getExpressionType() { + return returnType; + } + +} diff --git a/src/sqlancer/materialize/ast/MaterializeFunctionWithUnknownResult.java b/src/sqlancer/materialize/ast/MaterializeFunctionWithUnknownResult.java new file mode 100644 index 000000000..fb34c2247 --- /dev/null +++ b/src/sqlancer/materialize/ast/MaterializeFunctionWithUnknownResult.java @@ -0,0 +1,148 @@ +package sqlancer.materialize.ast; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; +import sqlancer.materialize.gen.MaterializeExpressionGenerator; + +public enum MaterializeFunctionWithUnknownResult { + + CURRENT_DATABASE("current_database", MaterializeDataType.TEXT), // name + CURRENT_SCHEMA("current_schema", MaterializeDataType.TEXT), // name + PG_BACKEND_PID("pg_backend_pid", MaterializeDataType.INT), + PG_CURRENT_LOGFILE("pg_current_logfile", MaterializeDataType.TEXT), + PG_IS_OTHER_TEMP_SCHEMA("pg_is_other_temp_schema", MaterializeDataType.BOOLEAN), + PG_JIT_AVAILABLE("pg_jit_available", MaterializeDataType.BOOLEAN), + PG_NOTIFICATION_QUEUE_USAGE("pg_notification_queue_usage", MaterializeDataType.REAL), + PG_TRIGGER_DEPTH("pg_trigger_depth", MaterializeDataType.INT), VERSION("version", MaterializeDataType.TEXT), + + // + TO_CHAR("to_char", MaterializeDataType.TEXT, MaterializeDataType.TEXT, MaterializeDataType.TEXT) { + @Override + public MaterializeExpression[] getArguments(MaterializeDataType returnType, MaterializeExpressionGenerator gen, + int depth) { + MaterializeExpression[] args = super.getArguments(returnType, gen, depth); + args[0] = gen.generateExpression(MaterializeDataType.getRandomType()); + return args; + } + }, + + // String functions + ASCII("ascii", MaterializeDataType.INT, MaterializeDataType.TEXT), + BTRIM("btrim", MaterializeDataType.TEXT, MaterializeDataType.TEXT, MaterializeDataType.TEXT), + CHR("chr", MaterializeDataType.TEXT, MaterializeDataType.INT), + CONVERT_FROM("convert_from", MaterializeDataType.TEXT, MaterializeDataType.TEXT, MaterializeDataType.TEXT) { + @Override + public MaterializeExpression[] getArguments(MaterializeDataType returnType, MaterializeExpressionGenerator gen, + int depth) { + MaterializeExpression[] args = super.getArguments(returnType, gen, depth); + args[1] = MaterializeConstant.createTextConstant(Randomly.fromOptions("UTF8", "LATIN1")); + return args; + } + }, + INITCAP("initcap", MaterializeDataType.TEXT, MaterializeDataType.TEXT), + LEFT("left", MaterializeDataType.TEXT, MaterializeDataType.INT, MaterializeDataType.TEXT), + LOWER("lower", MaterializeDataType.TEXT, MaterializeDataType.TEXT), + MD5("md5", MaterializeDataType.TEXT, MaterializeDataType.TEXT), + UPPER("upper", MaterializeDataType.TEXT, MaterializeDataType.TEXT), + QUOTE_LITERAL("quote_literal", MaterializeDataType.TEXT, MaterializeDataType.TEXT), + QUOTE_IDENT("quote_ident", MaterializeDataType.TEXT, MaterializeDataType.TEXT), + REGEX_REPLACE("regex_replace", MaterializeDataType.TEXT, MaterializeDataType.TEXT, MaterializeDataType.TEXT), + REPLACE("replace", MaterializeDataType.TEXT, MaterializeDataType.TEXT, MaterializeDataType.TEXT), + REVERSE("reverse", MaterializeDataType.TEXT, MaterializeDataType.TEXT), + RIGHT("right", MaterializeDataType.TEXT, MaterializeDataType.TEXT, MaterializeDataType.INT), + RPAD("rpad", MaterializeDataType.TEXT, MaterializeDataType.INT, MaterializeDataType.TEXT), + RTRIM("rtrim", MaterializeDataType.TEXT, MaterializeDataType.TEXT), + SPLIT_PART("split_part", MaterializeDataType.TEXT, MaterializeDataType.TEXT, MaterializeDataType.INT), + STRPOS("strpos", MaterializeDataType.INT, MaterializeDataType.TEXT, MaterializeDataType.TEXT), + SUBSTR("substr", MaterializeDataType.TEXT, MaterializeDataType.TEXT, MaterializeDataType.INT, + MaterializeDataType.INT), + TO_ASCII("to_ascii", MaterializeDataType.TEXT, MaterializeDataType.TEXT), + TO_HEX("to_hex", MaterializeDataType.INT, MaterializeDataType.TEXT), + TRANSLATE("translate", MaterializeDataType.TEXT, MaterializeDataType.TEXT, MaterializeDataType.TEXT, + MaterializeDataType.TEXT), + // mathematical functions + ABS("abs", MaterializeDataType.REAL, MaterializeDataType.REAL), + CBRT("cbrt", MaterializeDataType.REAL, MaterializeDataType.REAL), CEILING("ceiling", MaterializeDataType.REAL), // + DEGREES("degrees", MaterializeDataType.REAL), EXP("exp", MaterializeDataType.REAL), + LN("ln", MaterializeDataType.REAL), LOG("log", MaterializeDataType.REAL), + LOG2("log", MaterializeDataType.REAL, MaterializeDataType.REAL), PI("pi", MaterializeDataType.REAL), + POWER("power", MaterializeDataType.REAL, MaterializeDataType.REAL), + TRUNC("trunc", MaterializeDataType.REAL, MaterializeDataType.INT), + TRUNC2("trunc", MaterializeDataType.REAL, MaterializeDataType.INT, MaterializeDataType.REAL), + FLOOR("floor", MaterializeDataType.REAL), + + // trigonometric functions - complete + ACOS("acos", MaterializeDataType.REAL), // + ACOSD("acosd", MaterializeDataType.REAL), // + ASIN("asin", MaterializeDataType.REAL), // + ASIND("asind", MaterializeDataType.REAL), // + ATAN("atan", MaterializeDataType.REAL), // + ATAND("atand", MaterializeDataType.REAL), // + ATAN2("atan2", MaterializeDataType.REAL, MaterializeDataType.REAL), // + ATAN2D("atan2d", MaterializeDataType.REAL, MaterializeDataType.REAL), // + COS("cos", MaterializeDataType.REAL), // + COSD("cosd", MaterializeDataType.REAL), // + COT("cot", MaterializeDataType.REAL), // + COTD("cotd", MaterializeDataType.REAL), // + SIN("sin", MaterializeDataType.REAL), // + SIND("sind", MaterializeDataType.REAL), // + TAN("tan", MaterializeDataType.REAL), // + TAND("tand", MaterializeDataType.REAL), // + + // hyperbolic functions - complete + SINH("sinh", MaterializeDataType.REAL), // + COSH("cosh", MaterializeDataType.REAL), // + TANH("tanh", MaterializeDataType.REAL), // + ASINH("asinh", MaterializeDataType.REAL), // + ACOSH("acosh", MaterializeDataType.REAL), // + ATANH("atanh", MaterializeDataType.REAL), // + + GET_BIT("get_bit", MaterializeDataType.INT, MaterializeDataType.TEXT, MaterializeDataType.INT), + GET_BYTE("get_byte", MaterializeDataType.INT, MaterializeDataType.TEXT, MaterializeDataType.INT), + + GET_COLUMN_SIZE("get_column_size", MaterializeDataType.INT, MaterializeDataType.TEXT); + + private String functionName; + private MaterializeDataType returnType; + private MaterializeDataType[] argTypes; + + MaterializeFunctionWithUnknownResult(String functionName, MaterializeDataType returnType, + MaterializeDataType... indexType) { + this.functionName = functionName; + this.returnType = returnType; + this.argTypes = indexType.clone(); + + } + + public boolean isCompatibleWithReturnType(MaterializeDataType t) { + return t == returnType; + } + + public MaterializeExpression[] getArguments(MaterializeDataType returnType, MaterializeExpressionGenerator gen, + int depth) { + MaterializeExpression[] args = new MaterializeExpression[argTypes.length]; + for (int i = 0; i < args.length; i++) { + args[i] = gen.generateExpression(depth, argTypes[i]); + } + return args; + + } + + public String getName() { + return functionName; + } + + public static List getSupportedFunctions(MaterializeDataType type) { + List functions = new ArrayList<>(); + for (MaterializeFunctionWithUnknownResult func : values()) { + if (func.isCompatibleWithReturnType(type)) { + functions.add(func); + } + } + return functions; + } + +} diff --git a/src/sqlancer/materialize/ast/MaterializeInOperation.java b/src/sqlancer/materialize/ast/MaterializeInOperation.java new file mode 100644 index 000000000..270a973dc --- /dev/null +++ b/src/sqlancer/materialize/ast/MaterializeInOperation.java @@ -0,0 +1,66 @@ +package sqlancer.materialize.ast; + +import java.util.List; + +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; + +public class MaterializeInOperation implements MaterializeExpression { + + private final MaterializeExpression expr; + private final List listElements; + private final boolean isTrue; + + public MaterializeInOperation(MaterializeExpression expr, List listElements, + boolean isTrue) { + this.expr = expr; + this.listElements = listElements; + this.isTrue = isTrue; + } + + public MaterializeExpression getExpr() { + return expr; + } + + public List getListElements() { + return listElements; + } + + @Override + public MaterializeConstant getExpectedValue() { + MaterializeConstant leftValue = expr.getExpectedValue(); + if (leftValue == null) { + return null; + } + if (leftValue.isNull()) { + return MaterializeConstant.createNullConstant(); + } + boolean isNull = false; + for (MaterializeExpression expr : getListElements()) { + MaterializeConstant rightExpectedValue = expr.getExpectedValue(); + if (rightExpectedValue == null) { + return null; + } + if (rightExpectedValue.isNull()) { + isNull = true; + } else if (rightExpectedValue.isEquals(this.expr.getExpectedValue()).isBoolean() + && rightExpectedValue.isEquals(this.expr.getExpectedValue()).asBoolean()) { + return MaterializeConstant.createBooleanConstant(isTrue); + } + } + + if (isNull) { + return MaterializeConstant.createNullConstant(); + } else { + return MaterializeConstant.createBooleanConstant(!isTrue); + } + } + + public boolean isTrue() { + return isTrue; + } + + @Override + public MaterializeDataType getExpressionType() { + return MaterializeDataType.BOOLEAN; + } +} diff --git a/src/sqlancer/materialize/ast/MaterializeJoin.java b/src/sqlancer/materialize/ast/MaterializeJoin.java new file mode 100644 index 000000000..ef8f1ac61 --- /dev/null +++ b/src/sqlancer/materialize/ast/MaterializeJoin.java @@ -0,0 +1,58 @@ +package sqlancer.materialize.ast; + +import sqlancer.Randomly; +import sqlancer.common.ast.newast.Join; +import sqlancer.materialize.MaterializeSchema.MaterializeColumn; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; +import sqlancer.materialize.MaterializeSchema.MaterializeTable; + +public class MaterializeJoin + implements MaterializeExpression, Join { + + public enum MaterializeJoinType { + INNER, LEFT, RIGHT, FULL, CROSS; + + public static MaterializeJoinType getRandom() { + return Randomly.fromOptions(values()); + } + + } + + private MaterializeExpression onClause; + private final MaterializeExpression tableReference; + private final MaterializeJoinType type; + + public MaterializeJoin(MaterializeExpression tableReference, MaterializeExpression onClause, + MaterializeJoinType type) { + this.tableReference = tableReference; + this.onClause = onClause; + this.type = type; + } + + public MaterializeExpression getTableReference() { + return tableReference; + } + + public MaterializeExpression getOnClause() { + return onClause; + } + + public MaterializeJoinType getType() { + return type; + } + + @Override + public MaterializeDataType getExpressionType() { + throw new AssertionError(); + } + + @Override + public MaterializeConstant getExpectedValue() { + throw new AssertionError(); + } + + @Override + public void setOnClause(MaterializeExpression onClause) { + this.onClause = onClause; + } +} diff --git a/src/sqlancer/materialize/ast/MaterializeLikeOperation.java b/src/sqlancer/materialize/ast/MaterializeLikeOperation.java new file mode 100644 index 000000000..b14209993 --- /dev/null +++ b/src/sqlancer/materialize/ast/MaterializeLikeOperation.java @@ -0,0 +1,38 @@ +package sqlancer.materialize.ast; + +import sqlancer.LikeImplementationHelper; +import sqlancer.common.ast.BinaryNode; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; + +public class MaterializeLikeOperation extends BinaryNode implements MaterializeExpression { + + public MaterializeLikeOperation(MaterializeExpression left, MaterializeExpression right) { + super(left, right); + } + + @Override + public MaterializeDataType getExpressionType() { + return MaterializeDataType.BOOLEAN; + } + + @Override + public MaterializeConstant getExpectedValue() { + MaterializeConstant leftVal = getLeft().getExpectedValue(); + MaterializeConstant rightVal = getRight().getExpectedValue(); + if (leftVal == null || rightVal == null) { + return null; + } + if (leftVal.isNull() || rightVal.isNull()) { + return MaterializeConstant.createNullConstant(); + } else { + boolean val = LikeImplementationHelper.match(leftVal.asString(), rightVal.asString(), 0, 0, true); + return MaterializeConstant.createBooleanConstant(val); + } + } + + @Override + public String getOperatorRepresentation() { + return "LIKE"; + } + +} diff --git a/src/sqlancer/materialize/ast/MaterializeOrderByTerm.java b/src/sqlancer/materialize/ast/MaterializeOrderByTerm.java new file mode 100644 index 000000000..d4704c87f --- /dev/null +++ b/src/sqlancer/materialize/ast/MaterializeOrderByTerm.java @@ -0,0 +1,42 @@ +package sqlancer.materialize.ast; + +import sqlancer.Randomly; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; + +public class MaterializeOrderByTerm implements MaterializeExpression { + + private final MaterializeOrder order; + private final MaterializeExpression expr; + + public enum MaterializeOrder { + ASC, DESC; + + public static MaterializeOrder getRandomOrder() { + return Randomly.fromOptions(MaterializeOrder.values()); + } + } + + public MaterializeOrderByTerm(MaterializeExpression expr, MaterializeOrder order) { + this.expr = expr; + this.order = order; + } + + public MaterializeOrder getOrder() { + return order; + } + + public MaterializeExpression getExpr() { + return expr; + } + + @Override + public MaterializeConstant getExpectedValue() { + throw new AssertionError(this); + } + + @Override + public MaterializeDataType getExpressionType() { + return null; + } + +} diff --git a/src/sqlancer/materialize/ast/MaterializePOSIXRegularExpression.java b/src/sqlancer/materialize/ast/MaterializePOSIXRegularExpression.java new file mode 100644 index 000000000..127bafe79 --- /dev/null +++ b/src/sqlancer/materialize/ast/MaterializePOSIXRegularExpression.java @@ -0,0 +1,65 @@ +package sqlancer.materialize.ast; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; + +public class MaterializePOSIXRegularExpression implements MaterializeExpression { + + private MaterializeExpression string; + private MaterializeExpression regex; + private POSIXRegex op; + + public enum POSIXRegex implements Operator { + MATCH_CASE_SENSITIVE("~"), MATCH_CASE_INSENSITIVE("~*"), NOT_MATCH_CASE_SENSITIVE("!~"), + NOT_MATCH_CASE_INSENSITIVE("!~*"); + + private String repr; + + POSIXRegex(String repr) { + this.repr = repr; + } + + public String getStringRepresentation() { + return repr; + } + + public static POSIXRegex getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return toString(); + } + } + + public MaterializePOSIXRegularExpression(MaterializeExpression string, MaterializeExpression regex, POSIXRegex op) { + this.string = string; + this.regex = regex; + this.op = op; + } + + @Override + public MaterializeDataType getExpressionType() { + return MaterializeDataType.BOOLEAN; + } + + @Override + public MaterializeConstant getExpectedValue() { + return null; + } + + public MaterializeExpression getRegex() { + return regex; + } + + public MaterializeExpression getString() { + return string; + } + + public POSIXRegex getOp() { + return op; + } + +} diff --git a/src/sqlancer/materialize/ast/MaterializePostfixOperation.java b/src/sqlancer/materialize/ast/MaterializePostfixOperation.java new file mode 100644 index 000000000..7bee2e38c --- /dev/null +++ b/src/sqlancer/materialize/ast/MaterializePostfixOperation.java @@ -0,0 +1,151 @@ +package sqlancer.materialize.ast; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; + +public class MaterializePostfixOperation implements MaterializeExpression { + + private final MaterializeExpression expr; + private final PostfixOperator op; + private final String operatorTextRepresentation; + + public enum PostfixOperator implements Operator { + IS_NULL("IS NULL", "ISNULL") { + @Override + public MaterializeConstant apply(MaterializeConstant expectedValue) { + return MaterializeConstant.createBooleanConstant(expectedValue.isNull()); + } + + @Override + public MaterializeDataType[] getInputDataTypes() { + return MaterializeDataType.values(); + } + + }, + IS_UNKNOWN("IS UNKNOWN") { + @Override + public MaterializeConstant apply(MaterializeConstant expectedValue) { + return MaterializeConstant.createBooleanConstant(expectedValue.isNull()); + } + + @Override + public MaterializeDataType[] getInputDataTypes() { + return new MaterializeDataType[] { MaterializeDataType.BOOLEAN }; + } + }, + + IS_NOT_NULL("IS NOT NULL") { + + @Override + public MaterializeConstant apply(MaterializeConstant expectedValue) { + return MaterializeConstant.createBooleanConstant(!expectedValue.isNull()); + } + + @Override + public MaterializeDataType[] getInputDataTypes() { + return MaterializeDataType.values(); + } + + }, + IS_NOT_UNKNOWN("IS NOT UNKNOWN") { + @Override + public MaterializeConstant apply(MaterializeConstant expectedValue) { + return MaterializeConstant.createBooleanConstant(!expectedValue.isNull()); + } + + @Override + public MaterializeDataType[] getInputDataTypes() { + return new MaterializeDataType[] { MaterializeDataType.BOOLEAN }; + } + }, + IS_TRUE("IS TRUE") { + + @Override + public MaterializeConstant apply(MaterializeConstant expectedValue) { + if (expectedValue.isNull()) { + return MaterializeConstant.createFalse(); + } else { + return MaterializeConstant + .createBooleanConstant(expectedValue.cast(MaterializeDataType.BOOLEAN).asBoolean()); + } + } + + @Override + public MaterializeDataType[] getInputDataTypes() { + return new MaterializeDataType[] { MaterializeDataType.BOOLEAN }; + } + + }, + IS_FALSE("IS FALSE") { + + @Override + public MaterializeConstant apply(MaterializeConstant expectedValue) { + if (expectedValue.isNull()) { + return MaterializeConstant.createFalse(); + } else { + return MaterializeConstant + .createBooleanConstant(!expectedValue.cast(MaterializeDataType.BOOLEAN).asBoolean()); + } + } + + @Override + public MaterializeDataType[] getInputDataTypes() { + return new MaterializeDataType[] { MaterializeDataType.BOOLEAN }; + } + + }; + + private String[] textRepresentations; + + PostfixOperator(String... textRepresentations) { + this.textRepresentations = textRepresentations.clone(); + } + + public abstract MaterializeConstant apply(MaterializeConstant expectedValue); + + public abstract MaterializeDataType[] getInputDataTypes(); + + public static PostfixOperator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return toString(); + } + } + + public MaterializePostfixOperation(MaterializeExpression expr, PostfixOperator op) { + this.expr = expr; + this.operatorTextRepresentation = Randomly.fromOptions(op.textRepresentations); + this.op = op; + } + + @Override + public MaterializeDataType getExpressionType() { + return MaterializeDataType.BOOLEAN; + } + + @Override + public MaterializeConstant getExpectedValue() { + MaterializeConstant expectedValue = expr.getExpectedValue(); + if (expectedValue == null) { + return null; + } + return op.apply(expectedValue); + } + + public String getOperatorTextRepresentation() { + return operatorTextRepresentation; + } + + public static MaterializeExpression create(MaterializeExpression expr, PostfixOperator op) { + return new MaterializePostfixOperation(expr, op); + } + + public MaterializeExpression getExpression() { + return expr; + } + +} diff --git a/src/sqlancer/materialize/ast/MaterializePostfixText.java b/src/sqlancer/materialize/ast/MaterializePostfixText.java new file mode 100644 index 000000000..a64b03c6d --- /dev/null +++ b/src/sqlancer/materialize/ast/MaterializePostfixText.java @@ -0,0 +1,37 @@ +package sqlancer.materialize.ast; + +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; + +public class MaterializePostfixText implements MaterializeExpression { + + private final MaterializeExpression expr; + private final String text; + private final MaterializeConstant expectedValue; + private final MaterializeDataType type; + + public MaterializePostfixText(MaterializeExpression expr, String text, MaterializeConstant expectedValue, + MaterializeDataType type) { + this.expr = expr; + this.text = text; + this.expectedValue = expectedValue; + this.type = type; + } + + public MaterializeExpression getExpr() { + return expr; + } + + public String getText() { + return text; + } + + @Override + public MaterializeConstant getExpectedValue() { + return expectedValue; + } + + @Override + public MaterializeDataType getExpressionType() { + return type; + } +} diff --git a/src/sqlancer/materialize/ast/MaterializePrefixOperation.java b/src/sqlancer/materialize/ast/MaterializePrefixOperation.java new file mode 100644 index 000000000..456c65b45 --- /dev/null +++ b/src/sqlancer/materialize/ast/MaterializePrefixOperation.java @@ -0,0 +1,119 @@ +package sqlancer.materialize.ast; + +import sqlancer.IgnoreMeException; +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; + +public class MaterializePrefixOperation implements MaterializeExpression { + + public enum PrefixOperator implements Operator { + NOT("NOT", MaterializeDataType.BOOLEAN) { + + @Override + public MaterializeDataType getExpressionType() { + return MaterializeDataType.BOOLEAN; + } + + @Override + protected MaterializeConstant getExpectedValue(MaterializeConstant expectedValue) { + if (expectedValue.isNull()) { + return MaterializeConstant.createNullConstant(); + } else { + return MaterializeConstant + .createBooleanConstant(!expectedValue.cast(MaterializeDataType.BOOLEAN).asBoolean()); + } + } + }, + UNARY_PLUS("+", MaterializeDataType.INT) { + + @Override + public MaterializeDataType getExpressionType() { + return MaterializeDataType.INT; + } + + @Override + protected MaterializeConstant getExpectedValue(MaterializeConstant expectedValue) { + // TODO: actual converts to double precision + return expectedValue; + } + + }, + UNARY_MINUS("-", MaterializeDataType.INT) { + + @Override + public MaterializeDataType getExpressionType() { + return MaterializeDataType.INT; + } + + @Override + protected MaterializeConstant getExpectedValue(MaterializeConstant expectedValue) { + if (expectedValue.isNull()) { + // TODO + throw new IgnoreMeException(); + } + if (expectedValue.isInt() && expectedValue.asInt() == Long.MIN_VALUE) { + throw new IgnoreMeException(); + } + try { + return MaterializeConstant.createIntConstant(-expectedValue.asInt()); + } catch (UnsupportedOperationException e) { + return null; + } + } + + }; + + private String textRepresentation; + private MaterializeDataType[] dataTypes; + + PrefixOperator(String textRepresentation, MaterializeDataType... dataTypes) { + this.textRepresentation = textRepresentation; + this.dataTypes = dataTypes.clone(); + } + + public abstract MaterializeDataType getExpressionType(); + + protected abstract MaterializeConstant getExpectedValue(MaterializeConstant expectedValue); + + @Override + public String getTextRepresentation() { + return toString(); + } + + } + + private final MaterializeExpression expr; + private final PrefixOperator op; + + public MaterializePrefixOperation(MaterializeExpression expr, PrefixOperator op) { + this.expr = expr; + this.op = op; + } + + @Override + public MaterializeDataType getExpressionType() { + return op.getExpressionType(); + } + + @Override + public MaterializeConstant getExpectedValue() { + MaterializeConstant expectedValue = expr.getExpectedValue(); + if (expectedValue == null) { + return null; + } + return op.getExpectedValue(expectedValue); + } + + public MaterializeDataType[] getInputDataTypes() { + return op.dataTypes; + } + + public String getTextRepresentation() { + return op.textRepresentation; + } + + public MaterializeExpression getExpression() { + return expr; + } + +} diff --git a/src/sqlancer/materialize/ast/MaterializeSelect.java b/src/sqlancer/materialize/ast/MaterializeSelect.java new file mode 100644 index 000000000..db5b85a2f --- /dev/null +++ b/src/sqlancer/materialize/ast/MaterializeSelect.java @@ -0,0 +1,145 @@ +package sqlancer.materialize.ast; + +import java.util.Collections; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.ast.SelectBase; +import sqlancer.common.ast.newast.Select; +import sqlancer.materialize.MaterializeSchema.MaterializeColumn; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; +import sqlancer.materialize.MaterializeSchema.MaterializeTable; +import sqlancer.materialize.MaterializeVisitor; + +public class MaterializeSelect extends SelectBase implements MaterializeExpression, + Select { + + private SelectType selectOption = SelectType.ALL; + private List joinClauses = Collections.emptyList(); + private MaterializeExpression distinctOnClause; + private ForClause forClause; + + public enum ForClause { + UPDATE("UPDATE"), NO_KEY_UPDATE("NO KEY UPDATE"), SHARE("SHARE"), KEY_SHARE("KEY SHARE"); + + private final String textRepresentation; + + ForClause(String textRepresentation) { + this.textRepresentation = textRepresentation; + } + + public String getTextRepresentation() { + return textRepresentation; + } + + public static ForClause getRandom() { + return Randomly.fromOptions(values()); + } + } + + public static class MaterializeFromTable implements MaterializeExpression { + private final MaterializeTable t; + private final boolean only; + + public MaterializeFromTable(MaterializeTable t, boolean only) { + this.t = t; + this.only = only; + } + + public MaterializeTable getTable() { + return t; + } + + public boolean isOnly() { + return only; + } + + @Override + public MaterializeDataType getExpressionType() { + return null; + } + } + + public static class MaterializeSubquery implements MaterializeExpression { + private final MaterializeSelect s; + private final String name; + + public MaterializeSubquery(MaterializeSelect s, String name) { + this.s = s; + this.name = name; + } + + public MaterializeSelect getSelect() { + return s; + } + + public String getName() { + return name; + } + + @Override + public MaterializeDataType getExpressionType() { + return null; + } + } + + public enum SelectType { + DISTINCT, ALL; + + public static SelectType getRandom() { + return Randomly.fromOptions(values()); + } + } + + public void setSelectType(SelectType fromOptions) { + this.setSelectOption(fromOptions); + } + + public void setDistinctOnClause(MaterializeExpression distinctOnClause) { + if (selectOption != SelectType.DISTINCT) { + throw new IllegalArgumentException(); + } + this.distinctOnClause = distinctOnClause; + } + + public SelectType getSelectOption() { + return selectOption; + } + + public void setSelectOption(SelectType fromOptions) { + this.selectOption = fromOptions; + } + + @Override + public MaterializeDataType getExpressionType() { + return null; + } + + @Override + public void setJoinClauses(List joinStatements) { + this.joinClauses = joinStatements; + + } + + @Override + public List getJoinClauses() { + return joinClauses; + } + + public MaterializeExpression getDistinctOnClause() { + return distinctOnClause; + } + + public void setForClause(ForClause forClause) { + this.forClause = forClause; + } + + public ForClause getForClause() { + return forClause; + } + + @Override + public String asString() { + return MaterializeVisitor.asString(this); + } +} diff --git a/src/sqlancer/materialize/ast/MaterializeSimilarTo.java b/src/sqlancer/materialize/ast/MaterializeSimilarTo.java new file mode 100644 index 000000000..dd675050d --- /dev/null +++ b/src/sqlancer/materialize/ast/MaterializeSimilarTo.java @@ -0,0 +1,40 @@ +package sqlancer.materialize.ast; + +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; + +public class MaterializeSimilarTo implements MaterializeExpression { + + private final MaterializeExpression string; + private final MaterializeExpression similarTo; + private final MaterializeExpression escapeCharacter; + + public MaterializeSimilarTo(MaterializeExpression string, MaterializeExpression similarTo, + MaterializeExpression escapeCharacter) { + this.string = string; + this.similarTo = similarTo; + this.escapeCharacter = escapeCharacter; + } + + public MaterializeExpression getString() { + return string; + } + + public MaterializeExpression getSimilarTo() { + return similarTo; + } + + public MaterializeExpression getEscapeCharacter() { + return escapeCharacter; + } + + @Override + public MaterializeDataType getExpressionType() { + return MaterializeDataType.BOOLEAN; + } + + @Override + public MaterializeConstant getExpectedValue() { + return null; + } + +} diff --git a/src/sqlancer/materialize/gen/MaterializeCommon.java b/src/sqlancer/materialize/gen/MaterializeCommon.java new file mode 100644 index 000000000..d21c8f81c --- /dev/null +++ b/src/sqlancer/materialize/gen/MaterializeCommon.java @@ -0,0 +1,416 @@ +package sqlancer.materialize.gen; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Collectors; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.common.DBMSCommon; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.materialize.MaterializeGlobalState; +import sqlancer.materialize.MaterializeProvider; +import sqlancer.materialize.MaterializeSchema; +import sqlancer.materialize.MaterializeSchema.MaterializeColumn; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; +import sqlancer.materialize.MaterializeSchema.MaterializeTable; +import sqlancer.materialize.MaterializeVisitor; + +public final class MaterializeCommon { + + private MaterializeCommon() { + } + + public static List getCommonFetchErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("FULL JOIN is only supported with merge-joinable or hash-joinable join conditions"); + errors.add("but it cannot be referenced from this part of the query"); + errors.add("missing FROM-clause entry for table"); + + errors.add("canceling statement due to statement timeout"); + + errors.add("non-integer constant in GROUP BY"); + errors.add("must appear in the GROUP BY clause or be used in an aggregate function"); + errors.add("GROUP BY position"); + errors.add("result exceeds max size of"); + + errors.add("does not exist"); + errors.add("aggregate functions are not allowed in"); + errors.add("is only defined for finite arguments"); + + return errors; + } + + public static void addCommonFetchErrors(ExpectedErrors errors) { + errors.addAll(getCommonFetchErrors()); + } + + public static List getCommonTableErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("is not commutative"); // exclude + errors.add("operator requires run-time type coercion"); // exclude + + return errors; + } + + public static void addCommonTableErrors(ExpectedErrors errors) { + errors.addAll(getCommonTableErrors()); + } + + public static List getCommonExpressionErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("You might need to add explicit type casts"); + errors.add("invalid regular expression"); + errors.add("could not determine which collation to use"); + errors.add("invalid regular expression"); + errors.add("operator does not exist"); + errors.add("quantifier operand invalid"); + errors.add("collation mismatch"); + errors.add("collations are not supported"); + errors.add("operator is not unique"); + errors.add("is not a valid binary digit"); + errors.add("invalid hexadecimal digit"); + errors.add("invalid hexadecimal data: odd number of digits"); + errors.add("zero raised to a negative power is undefined"); + errors.add("cannot convert infinity to numeric"); + errors.add("division by zero"); + errors.add("invalid input syntax for type money"); + errors.add("invalid input syntax for type"); + errors.add("cannot cast type"); + errors.add("value overflows numeric format"); + errors.add("numeric field overflow"); + errors.add("LIKE pattern must not end with escape character"); + errors.add("is of type boolean but expression is of type text"); + errors.add("a negative number raised to a non-integer power yields a complex result"); + errors.add("could not determine polymorphic type because input has type unknown"); + errors.add("character number must be positive"); + errors.add("unterminated escape sequence"); + errors.add("cannot be matched"); + errors.add("clause must have type"); // "not" in having doesn't work + errors.add("argument must have type"); // "not" in having doesn't work + errors.add("CAST does not support casting from"); + errors.add("aggregate functions are not allowed in"); + errors.add("only defined for finite arguments"); + errors.add("unable to parse column reference in GROUP BY clause"); // TODO + errors.addAll(getToCharFunctionErrors()); + errors.addAll(getBitStringOperationErrors()); + errors.addAll(getFunctionErrors()); + errors.addAll(getCommonRangeExpressionErrors()); + errors.addAll(getCommonRegexExpressionErrors()); + + return errors; + } + + public static void addCommonExpressionErrors(ExpectedErrors errors) { + errors.addAll(getCommonExpressionErrors()); + } + + private static List getToCharFunctionErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("multiple decimal points"); + errors.add("and decimal point together"); + errors.add("multiple decimal points"); + errors.add("cannot use \"S\" twice"); + errors.add("must be ahead of \"PR\""); + errors.add("cannot use \"S\" and \"PL\"/\"MI\"/\"SG\"/\"PR\" together"); + errors.add("cannot use \"S\" and \"SG\" together"); + errors.add("cannot use \"S\" and \"MI\" together"); + errors.add("cannot use \"S\" and \"PL\" together"); + errors.add("cannot use \"PR\" and \"S\"/\"PL\"/\"MI\"/\"SG\" together"); + errors.add("is not a number"); + + return errors; + } + + private static List getBitStringOperationErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("cannot XOR bit strings of different sizes"); + errors.add("cannot AND bit strings of different sizes"); + errors.add("cannot OR bit strings of different sizes"); + errors.add("must be type boolean, not type text"); + + return errors; + } + + private static List getFunctionErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("out of valid range"); // get_bit/get_byte + errors.add("cannot take logarithm of a negative number"); + errors.add("cannot take logarithm of zero"); + errors.add("requested character too large for encoding"); // chr + errors.add("null character not permitted"); // chr + errors.add("requested character not valid for encoding"); // chr + errors.add("requested length too large"); // repeat + errors.add("invalid memory alloc request size"); // repeat + errors.add("encoding conversion from UTF8 to ASCII not supported"); // to_ascii + errors.add("negative substring length not allowed"); // substr + errors.add("invalid mask length"); // set_masklen + + return errors; + } + + private static List getCommonRegexExpressionErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("is not a valid hexadecimal digit"); + + return errors; + } + + public static List getCommonRangeExpressionErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("range lower bound must be less than or equal to range upper bound"); + errors.add("result of range difference would not be contiguous"); + errors.add("out of range"); + errors.add("malformed range literal"); + errors.add("result of range union would not be contiguous"); + + return errors; + } + + public static void addCommonRangeExpressionErrors(ExpectedErrors errors) { + errors.addAll(getCommonExpressionErrors()); + } + + public static List getCommonInsertUpdateErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("value too long for type character"); + errors.add("not found in view targetlist"); + errors.add("CAST does not support casting from"); + + return errors; + } + + public static void addCommonInsertUpdateErrors(ExpectedErrors errors) { + errors.addAll(getCommonExpressionErrors()); + } + + public static List getGroupingErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("non-integer constant in GROUP BY"); // TODO + errors.add("unable to parse column reference in GROUP BY clause"); // TODO + errors.add("must appear in the GROUP BY clause or be used in an aggregate function"); + errors.add("is not in select list"); + errors.add("aggregate functions are not allowed in"); + + return errors; + } + + public static void addGroupingErrors(ExpectedErrors errors) { + errors.addAll(getGroupingErrors()); + } + + public static boolean appendDataType(MaterializeDataType type, StringBuilder sb, boolean allowSerial, + boolean generateOnlyKnown, List opClasses) throws AssertionError { + boolean serial = false; + switch (type) { + case BOOLEAN: + sb.append("boolean"); + break; + case INT: + sb.append(Randomly.fromOptions("smallint", "integer", "bigint")); + break; + case TEXT: + if (Randomly.getBoolean()) { + sb.append("TEXT"); + } else { + if (MaterializeProvider.generateOnlyKnown || Randomly.getBoolean()) { + sb.append("VAR"); + } + sb.append("CHAR"); + sb.append("("); + sb.append(ThreadLocalRandom.current().nextInt(1, 500)); + sb.append(")"); + } + break; + case DECIMAL: + sb.append("DECIMAL"); + break; + case FLOAT: + sb.append("REAL"); + break; + case REAL: + sb.append("FLOAT"); + break; + case BIT: + sb.append("INT"); + break; + default: + throw new AssertionError(type); + } + return serial; + } + + public enum TableConstraints { + CHECK, PRIMARY_KEY, FOREIGN_KEY, EXCLUDE + } + + public static void addTableConstraints(boolean excludePrimaryKey, StringBuilder sb, MaterializeTable table, + MaterializeGlobalState globalState, ExpectedErrors errors) { + // TODO constraint name + List tableConstraints = Randomly.nonEmptySubset(TableConstraints.values()); + if (excludePrimaryKey) { + tableConstraints.remove(TableConstraints.PRIMARY_KEY); + } + if (globalState.getSchema().getDatabaseTables().isEmpty()) { + tableConstraints.remove(TableConstraints.FOREIGN_KEY); + } + for (TableConstraints t : tableConstraints) { + sb.append(", "); + // TODO add index parameters + addTableConstraint(sb, table, globalState, t, errors); + } + } + + public static void addTableConstraint(StringBuilder sb, MaterializeTable table, MaterializeGlobalState globalState, + ExpectedErrors errors) { + addTableConstraint(sb, table, globalState, Randomly.fromOptions(TableConstraints.values()), errors); + } + + private static void addTableConstraint(StringBuilder sb, MaterializeTable table, MaterializeGlobalState globalState, + TableConstraints t, ExpectedErrors errors) { + List randomNonEmptyColumnSubset = table.getRandomNonEmptyColumnSubset(); + List otherColumns; + MaterializeCommon.addCommonExpressionErrors(errors); + switch (t) { + case CHECK: + sb.append("CHECK("); + sb.append(MaterializeVisitor.getExpressionAsString(globalState, MaterializeDataType.BOOLEAN, + table.getColumns())); + sb.append(")"); + errors.add("constraint must be added to child tables too"); + errors.add("missing FROM-clause entry for table"); + break; + case PRIMARY_KEY: + sb.append("PRIMARY KEY("); + sb.append(randomNonEmptyColumnSubset.stream().map(c -> c.getName()).collect(Collectors.joining(", "))); + sb.append(")"); + break; + case FOREIGN_KEY: + sb.append("FOREIGN KEY ("); + sb.append(randomNonEmptyColumnSubset.stream().map(c -> c.getName()).collect(Collectors.joining(", "))); + sb.append(") REFERENCES "); + MaterializeTable randomOtherTable = globalState.getSchema().getRandomTable(tab -> !tab.isView()); + sb.append(randomOtherTable.getName()); + if (randomOtherTable.getColumns().size() < randomNonEmptyColumnSubset.size()) { + throw new IgnoreMeException(); + } + otherColumns = randomOtherTable.getRandomNonEmptyColumnSubset(randomNonEmptyColumnSubset.size()); + sb.append("("); + sb.append(otherColumns.stream().map(c -> c.getName()).collect(Collectors.joining(", "))); + sb.append(")"); + if (Randomly.getBoolean()) { + sb.append(" "); + sb.append(Randomly.fromOptions("MATCH FULL", "MATCH SIMPLE")); + } + if (Randomly.getBoolean()) { + sb.append(" ON DELETE "); + errors.add("ERROR: invalid ON DELETE action for foreign key constraint containing generated column"); + deleteOrUpdateAction(sb); + } + if (Randomly.getBoolean()) { + sb.append(" ON UPDATE "); + errors.add("invalid ON UPDATE action for foreign key constraint containing generated column"); + deleteOrUpdateAction(sb); + } + if (Randomly.getBoolean()) { + sb.append(" "); + if (Randomly.getBoolean()) { + sb.append("DEFERRABLE"); + if (Randomly.getBoolean()) { + sb.append(" "); + sb.append(Randomly.fromOptions("INITIALLY DEFERRED", "INITIALLY IMMEDIATE")); + } + } else { + sb.append("NOT DEFERRABLE"); + } + } + break; + case EXCLUDE: + sb.append("EXCLUDE "); + sb.append("("); + // TODO [USING index_method ] + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + if (i != 0) { + sb.append(", "); + } + appendExcludeElement(sb, globalState, table.getColumns()); + sb.append(" WITH "); + appendOperator(sb, globalState.getOperators()); + } + sb.append(")"); + errors.add("is not valid"); + errors.add("no operator matches"); + errors.add("operator does not exist"); + errors.add("unknown has no default operator class"); + errors.add("exclusion constraints are not supported on partitioned tables"); + errors.add("The exclusion operator must be related to the index operator class for the constraint"); + errors.add("could not create exclusion constraint"); + // TODO: index parameters + if (Randomly.getBoolean()) { + sb.append(" WHERE "); + sb.append("("); + sb.append(MaterializeVisitor.asString(MaterializeExpressionGenerator.generateExpression(globalState, + table.getColumns(), MaterializeDataType.BOOLEAN))); + sb.append(")"); + } + break; + default: + throw new AssertionError(t); + } + } + + private static void appendOperator(StringBuilder sb, List operators) { + sb.append(Randomly.fromList(operators)); + } + + // complete + private static void appendExcludeElement(StringBuilder sb, MaterializeGlobalState globalState, + List columns) { + if (Randomly.getBoolean()) { + // append column name + sb.append(Randomly.fromList(columns).getName()); + } else { + // append expression + sb.append("("); + sb.append(MaterializeVisitor + .asString(MaterializeExpressionGenerator.generateExpression(globalState, columns))); + sb.append(")"); + } + if (Randomly.getBoolean()) { + sb.append(" "); + sb.append(Randomly.fromList(globalState.getOpClasses())); + } + if (Randomly.getBoolean()) { + sb.append(" "); + sb.append(Randomly.fromOptions("ASC", "DESC")); + } + if (Randomly.getBoolean()) { + sb.append(" NULLS "); + sb.append(Randomly.fromOptions("FIRST", "LAST")); + } + } + + private static void deleteOrUpdateAction(StringBuilder sb) { + sb.append(Randomly.fromOptions("NO ACTION", "RESTRICT", "CASCADE", "SET NULL", "SET DEFAULT")); + } + + public static String getFreeIndexName(MaterializeSchema s) { + List indexNames = s.getIndexNames(); + String candidateName; + do { + candidateName = DBMSCommon.createIndexName((int) Randomly.getNotCachedInteger(0, 100)); + } while (indexNames.contains(candidateName)); + return candidateName; + } +} diff --git a/src/sqlancer/materialize/gen/MaterializeDeleteGenerator.java b/src/sqlancer/materialize/gen/MaterializeDeleteGenerator.java new file mode 100644 index 000000000..ded3e53f4 --- /dev/null +++ b/src/sqlancer/materialize/gen/MaterializeDeleteGenerator.java @@ -0,0 +1,38 @@ +package sqlancer.materialize.gen; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.materialize.MaterializeGlobalState; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; +import sqlancer.materialize.MaterializeSchema.MaterializeTable; +import sqlancer.materialize.MaterializeVisitor; + +public final class MaterializeDeleteGenerator { + + private MaterializeDeleteGenerator() { + } + + public static SQLQueryAdapter create(MaterializeGlobalState globalState) { + MaterializeTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + ExpectedErrors errors = new ExpectedErrors(); + errors.add("violates foreign key constraint"); + errors.add("violates not-null constraint"); + errors.add("could not determine which collation to use for string comparison"); + StringBuilder sb = new StringBuilder("DELETE FROM"); + sb.append(" "); + sb.append(table.getName()); + if (Randomly.getBoolean()) { + sb.append(" WHERE "); + sb.append(MaterializeVisitor.asString(MaterializeExpressionGenerator.generateExpression(globalState, + table.getColumns(), MaterializeDataType.BOOLEAN))); + } + MaterializeCommon.addCommonExpressionErrors(errors); + errors.add("out of range"); + errors.add("does not support casting"); + errors.add("invalid input syntax for"); + errors.add("division by zero"); + return new SQLQueryAdapter(sb.toString(), errors); + } + +} diff --git a/src/sqlancer/materialize/gen/MaterializeDropIndexGenerator.java b/src/sqlancer/materialize/gen/MaterializeDropIndexGenerator.java new file mode 100644 index 000000000..4acbd329a --- /dev/null +++ b/src/sqlancer/materialize/gen/MaterializeDropIndexGenerator.java @@ -0,0 +1,53 @@ +package sqlancer.materialize.gen; + +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.DBMSCommon; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.materialize.MaterializeGlobalState; +import sqlancer.materialize.MaterializeSchema.MaterializeIndex; +import sqlancer.materialize.MaterializeSchema.MaterializeTable; + +public final class MaterializeDropIndexGenerator { + + private MaterializeDropIndexGenerator() { + } + + public static SQLQueryAdapter create(MaterializeGlobalState globalState) { + MaterializeTable randomTable = globalState.getSchema().getRandomTable(); + List indexes = randomTable.getIndexes(); + StringBuilder sb = new StringBuilder(); + sb.append("DROP INDEX "); + if (Randomly.getBoolean() || indexes.isEmpty()) { + sb.append("IF EXISTS "); + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + if (i != 0) { + sb.append(", "); + } + if (indexes.isEmpty() || Randomly.getBoolean()) { + sb.append(DBMSCommon.createIndexName(Randomly.smallNumber())); + } else { + sb.append(Randomly.fromList(indexes).getIndexName()); + } + } + } else { + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + if (i != 0) { + sb.append(", "); + } + sb.append(Randomly.fromList(indexes).getIndexName()); + } + } + if (Randomly.getBoolean()) { + sb.append(" "); + sb.append(Randomly.fromOptions("CASCADE", "RESTRICT")); + } + return new SQLQueryAdapter(sb.toString(), + ExpectedErrors.from("cannot drop desired object(s) because other objects depend on them", + "cannot drop index", "does not exist"), + true); + } + +} diff --git a/src/sqlancer/materialize/gen/MaterializeExpressionGenerator.java b/src/sqlancer/materialize/gen/MaterializeExpressionGenerator.java new file mode 100644 index 000000000..f7ff76305 --- /dev/null +++ b/src/sqlancer/materialize/gen/MaterializeExpressionGenerator.java @@ -0,0 +1,632 @@ +package sqlancer.materialize.gen; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.common.gen.ExpressionGenerator; +import sqlancer.common.gen.NoRECGenerator; +import sqlancer.common.gen.TLPWhereGenerator; +import sqlancer.common.schema.AbstractTables; +import sqlancer.materialize.MaterializeCompoundDataType; +import sqlancer.materialize.MaterializeGlobalState; +import sqlancer.materialize.MaterializeProvider; +import sqlancer.materialize.MaterializeSchema.MaterializeColumn; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; +import sqlancer.materialize.MaterializeSchema.MaterializeRowValue; +import sqlancer.materialize.MaterializeSchema.MaterializeTable; +import sqlancer.materialize.MaterializeSchema.MaterializeTables; +import sqlancer.materialize.ast.MaterializeAggregate; +import sqlancer.materialize.ast.MaterializeAggregate.MaterializeAggregateFunction; +import sqlancer.materialize.ast.MaterializeBetweenOperation; +import sqlancer.materialize.ast.MaterializeBinaryArithmeticOperation; +import sqlancer.materialize.ast.MaterializeBinaryArithmeticOperation.MaterializeBinaryOperator; +import sqlancer.materialize.ast.MaterializeBinaryBitOperation; +import sqlancer.materialize.ast.MaterializeBinaryBitOperation.MaterializeBinaryBitOperator; +import sqlancer.materialize.ast.MaterializeBinaryComparisonOperation; +import sqlancer.materialize.ast.MaterializeBinaryLogicalOperation; +import sqlancer.materialize.ast.MaterializeBinaryLogicalOperation.BinaryLogicalOperator; +import sqlancer.materialize.ast.MaterializeCastOperation; +import sqlancer.materialize.ast.MaterializeColumnValue; +import sqlancer.materialize.ast.MaterializeConcatOperation; +import sqlancer.materialize.ast.MaterializeConstant; +import sqlancer.materialize.ast.MaterializeExpression; +import sqlancer.materialize.ast.MaterializeFunction; +import sqlancer.materialize.ast.MaterializeFunction.MaterializeFunctionWithResult; +import sqlancer.materialize.ast.MaterializeFunctionWithUnknownResult; +import sqlancer.materialize.ast.MaterializeInOperation; +import sqlancer.materialize.ast.MaterializeJoin; +import sqlancer.materialize.ast.MaterializeJoin.MaterializeJoinType; +import sqlancer.materialize.ast.MaterializeLikeOperation; +import sqlancer.materialize.ast.MaterializeOrderByTerm; +import sqlancer.materialize.ast.MaterializeOrderByTerm.MaterializeOrder; +import sqlancer.materialize.ast.MaterializePOSIXRegularExpression; +import sqlancer.materialize.ast.MaterializePOSIXRegularExpression.POSIXRegex; +import sqlancer.materialize.ast.MaterializePostfixOperation; +import sqlancer.materialize.ast.MaterializePostfixOperation.PostfixOperator; +import sqlancer.materialize.ast.MaterializePostfixText; +import sqlancer.materialize.ast.MaterializePrefixOperation; +import sqlancer.materialize.ast.MaterializePrefixOperation.PrefixOperator; +import sqlancer.materialize.ast.MaterializeSelect; +import sqlancer.materialize.ast.MaterializeSelect.MaterializeFromTable; +import sqlancer.materialize.ast.MaterializeSelect.MaterializeSubquery; +import sqlancer.materialize.ast.MaterializeSelect.SelectType; +import sqlancer.materialize.oracle.tlp.MaterializeTLPBase; + +public class MaterializeExpressionGenerator implements ExpressionGenerator, + NoRECGenerator, + TLPWhereGenerator { + + private final int maxDepth; + + private final Randomly r; + + private List columns; + + private List tables; + + private MaterializeRowValue rw; + + private boolean expectedResult; + + private MaterializeGlobalState globalState; + + private boolean allowAggregateFunctions; + + private final Map functionsAndTypes; + + private final List allowedFunctionTypes; + + public MaterializeExpressionGenerator(MaterializeGlobalState globalState) { + this.r = globalState.getRandomly(); + this.maxDepth = globalState.getOptions().getMaxExpressionDepth(); + this.globalState = globalState; + this.functionsAndTypes = globalState.getFunctionsAndTypes(); + this.allowedFunctionTypes = globalState.getAllowedFunctionTypes(); + } + + public MaterializeExpressionGenerator setColumns(List columns) { + this.columns = columns; + return this; + } + + public MaterializeExpressionGenerator setRowValue(MaterializeRowValue rw) { + this.rw = rw; + return this; + } + + public MaterializeExpression generateExpression(int depth) { + return generateExpression(depth, MaterializeDataType.getRandomType()); + } + + @Override + public List generateOrderBys() { + List orderBys = new ArrayList<>(); + for (int i = 0; i < Randomly.smallNumber(); i++) { + orderBys.add(new MaterializeOrderByTerm(MaterializeColumnValue.create(Randomly.fromList(columns), null), + MaterializeOrder.getRandomOrder())); + } + return orderBys; + } + + private enum BooleanExpression { + POSTFIX_OPERATOR, NOT, BINARY_LOGICAL_OPERATOR, BINARY_COMPARISON, FUNCTION, LIKE, BETWEEN, IN_OPERATION, + POSIX_REGEX; + } + + private MaterializeExpression generateFunctionWithUnknownResult(int depth, MaterializeDataType type) { + List supportedFunctions = MaterializeFunctionWithUnknownResult + .getSupportedFunctions(type); + // filters functions by allowed type (STABLE 's', IMMUTABLE 'i', VOLATILE 'v') + supportedFunctions = supportedFunctions.stream() + .filter(f -> allowedFunctionTypes.contains(functionsAndTypes.get(f.getName()))) + .collect(Collectors.toList()); + if (supportedFunctions.isEmpty()) { + throw new IgnoreMeException(); + } + MaterializeFunctionWithUnknownResult randomFunction = Randomly.fromList(supportedFunctions); + return new MaterializeFunction(randomFunction, type, randomFunction.getArguments(type, this, depth + 1)); + } + + private MaterializeExpression generateFunctionWithKnownResult(int depth, MaterializeDataType type) { + List functions = Stream + .of(MaterializeFunction.MaterializeFunctionWithResult.values()).filter(f -> f.supportsReturnType(type)) + .collect(Collectors.toList()); + // filters functions by allowed type (STABLE 's', IMMUTABLE 'i', VOLATILE 'v') + functions = functions.stream().filter(f -> allowedFunctionTypes.contains(functionsAndTypes.get(f.getName()))) + .collect(Collectors.toList()); + if (functions.isEmpty()) { + throw new IgnoreMeException(); + } + MaterializeFunctionWithResult randomFunction = Randomly.fromList(functions); + int nrArgs = randomFunction.getNrArgs(); + if (randomFunction.isVariadic()) { + nrArgs += Randomly.smallNumber(); + } + MaterializeDataType[] argTypes = randomFunction.getInputTypesForReturnType(type, nrArgs); + MaterializeExpression[] args = new MaterializeExpression[nrArgs]; + do { + for (int i = 0; i < args.length; i++) { + args[i] = generateExpression(depth + 1, argTypes[i]); + } + } while (!randomFunction.checkArguments(args)); + return new MaterializeFunction(randomFunction, type, args); + } + + private MaterializeExpression generateBooleanExpression(int depth) { + List validOptions = new ArrayList<>(Arrays.asList(BooleanExpression.values())); + if (MaterializeProvider.generateOnlyKnown) { + validOptions.remove(BooleanExpression.POSIX_REGEX); + } + BooleanExpression option = Randomly.fromList(validOptions); + switch (option) { + case POSTFIX_OPERATOR: + PostfixOperator random = PostfixOperator.getRandom(); + return MaterializePostfixOperation + .create(generateExpression(depth + 1, Randomly.fromOptions(random.getInputDataTypes())), random); + case IN_OPERATION: + return inOperation(depth + 1); + case NOT: + return new MaterializePrefixOperation(generateExpression(depth + 1, MaterializeDataType.BOOLEAN), + PrefixOperator.NOT); + case BINARY_LOGICAL_OPERATOR: + MaterializeExpression first = generateExpression(depth + 1, MaterializeDataType.BOOLEAN); + int nr = Randomly.smallNumber() + 1; + for (int i = 0; i < nr; i++) { + first = new MaterializeBinaryLogicalOperation(first, + generateExpression(depth + 1, MaterializeDataType.BOOLEAN), BinaryLogicalOperator.getRandom()); + } + return first; + case BINARY_COMPARISON: + MaterializeDataType dataType = getMeaningfulType(); + return generateComparison(depth, dataType); + case FUNCTION: + return generateFunction(depth + 1, MaterializeDataType.BOOLEAN); + case LIKE: + return new MaterializeLikeOperation(generateExpression(depth + 1, MaterializeDataType.TEXT), + generateExpression(depth + 1, MaterializeDataType.TEXT)); + case BETWEEN: + MaterializeDataType type = getMeaningfulType(); + return new MaterializeBetweenOperation(generateExpression(depth + 1, type), + generateExpression(depth + 1, type), generateExpression(depth + 1, type), Randomly.getBoolean()); + case POSIX_REGEX: + assert !expectedResult; + return new MaterializePOSIXRegularExpression(generateExpression(depth + 1, MaterializeDataType.TEXT), + generateExpression(depth + 1, MaterializeDataType.TEXT), POSIXRegex.getRandom()); + default: + throw new AssertionError(); + } + } + + private MaterializeDataType getMeaningfulType() { + // make it more likely that the expression does not only consist of constant + // expressions + if (Randomly.getBooleanWithSmallProbability() || columns == null || columns.isEmpty()) { + return MaterializeDataType.getRandomType(); + } else { + return Randomly.fromList(columns).getType(); + } + } + + private MaterializeExpression generateFunction(int depth, MaterializeDataType type) { + if (MaterializeProvider.generateOnlyKnown || Randomly.getBoolean()) { + return generateFunctionWithKnownResult(depth, type); + } else { + return generateFunctionWithUnknownResult(depth, type); + } + } + + private MaterializeExpression generateComparison(int depth, MaterializeDataType dataType) { + MaterializeExpression leftExpr = generateExpression(depth + 1, dataType); + MaterializeExpression rightExpr = generateExpression(depth + 1, dataType); + return getComparison(leftExpr, rightExpr); + } + + private MaterializeExpression getComparison(MaterializeExpression leftExpr, MaterializeExpression rightExpr) { + return new MaterializeBinaryComparisonOperation(leftExpr, rightExpr, + MaterializeBinaryComparisonOperation.MaterializeBinaryComparisonOperator.getRandom()); + } + + private MaterializeExpression inOperation(int depth) { + MaterializeDataType type = MaterializeDataType.getRandomType(); + MaterializeExpression leftExpr = generateExpression(depth + 1, type); + List rightExpr = new ArrayList<>(); + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + rightExpr.add(generateExpression(depth + 1, type)); + } + return new MaterializeInOperation(leftExpr, rightExpr, Randomly.getBoolean()); + } + + public static MaterializeExpression generateExpression(MaterializeGlobalState globalState, + MaterializeDataType type) { + return new MaterializeExpressionGenerator(globalState).generateExpression(0, type); + } + + public MaterializeExpression generateExpression(int depth, MaterializeDataType originalType) { + MaterializeDataType dataType = originalType; + if (dataType == MaterializeDataType.REAL && Randomly.getBoolean()) { + dataType = Randomly.fromOptions(MaterializeDataType.INT, MaterializeDataType.FLOAT); + } + if (dataType == MaterializeDataType.FLOAT && Randomly.getBoolean()) { + dataType = MaterializeDataType.INT; + } + if (!filterColumns(dataType).isEmpty() && Randomly.getBoolean()) { + return createColumnOfType(dataType); + } + return generateExpressionInternal(depth, dataType); + } + + private MaterializeExpression generateExpressionInternal(int depth, MaterializeDataType dataType) + throws AssertionError { + if (allowAggregateFunctions && Randomly.getBoolean()) { + allowAggregateFunctions = false; // aggregate function calls cannot be nested + return getAggregate(dataType); + } + if (Randomly.getBooleanWithRatherLowProbability() || depth > maxDepth) { + // generic expression + if (Randomly.getBoolean() || depth > maxDepth) { + if (Randomly.getBooleanWithRatherLowProbability()) { + return generateConstant(r, dataType); + } else { + if (filterColumns(dataType).isEmpty()) { + return generateConstant(r, dataType); + } else { + return createColumnOfType(dataType); + } + } + } else { + if (Randomly.getBoolean()) { + return new MaterializeCastOperation(generateExpression(depth + 1), getCompoundDataType(dataType)); + } else { + return generateFunctionWithUnknownResult(depth, dataType); + } + } + } else { + switch (dataType) { + case BOOLEAN: + return generateBooleanExpression(depth); + case INT: + return generateIntExpression(depth); + case TEXT: + return generateTextExpression(depth); + case DECIMAL: + case REAL: + case FLOAT: + return generateConstant(r, dataType); + case BIT: + return generateBitExpression(depth); + default: + throw new AssertionError(dataType); + } + } + } + + private static MaterializeCompoundDataType getCompoundDataType(MaterializeDataType type) { + switch (type) { + case BOOLEAN: + case DECIMAL: // TODO + case FLOAT: + case INT: + case REAL: + case BIT: + return MaterializeCompoundDataType.create(type); + case TEXT: // TODO + if (Randomly.getBoolean() || MaterializeProvider.generateOnlyKnown /* + * The PQS implementation does not check + * for size specifications + */) { + return MaterializeCompoundDataType.create(type); + } else { + return MaterializeCompoundDataType.create(type, (int) Randomly.getNotCachedInteger(1, 1000)); + } + default: + throw new AssertionError(type); + } + + } + + private enum TextExpression { + CAST, FUNCTION, CONCAT + } + + private MaterializeExpression generateTextExpression(int depth) { + TextExpression option; + List validOptions = new ArrayList<>(Arrays.asList(TextExpression.values())); + option = Randomly.fromList(validOptions); + + switch (option) { + case CAST: + return new MaterializeCastOperation(generateExpression(depth + 1), + getCompoundDataType(MaterializeDataType.TEXT)); + case FUNCTION: + return generateFunction(depth + 1, MaterializeDataType.TEXT); + case CONCAT: + return generateConcat(depth); + default: + throw new AssertionError(); + } + } + + private MaterializeExpression generateConcat(int depth) { + MaterializeExpression left = generateExpression(depth + 1, MaterializeDataType.TEXT); + MaterializeExpression right = generateExpression(depth + 1); + return new MaterializeConcatOperation(left, right); + } + + private enum BitExpression { + BINARY_OPERATION + }; + + private MaterializeExpression generateBitExpression(int depth) { + BitExpression option; + option = Randomly.fromOptions(BitExpression.values()); + switch (option) { + case BINARY_OPERATION: + return new MaterializeBinaryBitOperation(MaterializeBinaryBitOperator.getRandom(), + generateExpression(depth + 1, MaterializeDataType.BIT), + generateExpression(depth + 1, MaterializeDataType.BIT)); + default: + throw new AssertionError(); + } + } + + private enum IntExpression { + UNARY_OPERATION, FUNCTION, /* CAST, */BINARY_ARITHMETIC_EXPRESSION + } + + private MaterializeExpression generateIntExpression(int depth) { + IntExpression option; + option = Randomly.fromOptions(IntExpression.values()); + switch (option) { + case UNARY_OPERATION: + MaterializeExpression intExpression = generateExpression(depth + 1, MaterializeDataType.INT); + return new MaterializePrefixOperation(intExpression, + Randomly.getBoolean() ? PrefixOperator.UNARY_PLUS : PrefixOperator.UNARY_MINUS); + case FUNCTION: + return generateFunction(depth + 1, MaterializeDataType.INT); + case BINARY_ARITHMETIC_EXPRESSION: + return new MaterializeBinaryArithmeticOperation(generateExpression(depth + 1, MaterializeDataType.INT), + generateExpression(depth + 1, MaterializeDataType.INT), MaterializeBinaryOperator.getRandom()); + default: + throw new AssertionError(); + } + } + + private MaterializeExpression createColumnOfType(MaterializeDataType type) { + List columns = filterColumns(type); + MaterializeColumn fromList = Randomly.fromList(columns); + MaterializeConstant value = rw == null ? null : rw.getValues().get(fromList); + return MaterializeColumnValue.create(fromList, value); + } + + final List filterColumns(MaterializeDataType type) { + if (columns == null) { + return Collections.emptyList(); + } else { + return columns.stream().filter(c -> c.getType() == type).collect(Collectors.toList()); + } + } + + public MaterializeExpression generateExpressionWithExpectedResult(MaterializeDataType type) { + this.expectedResult = true; + MaterializeExpressionGenerator gen = new MaterializeExpressionGenerator(globalState).setColumns(columns) + .setRowValue(rw); + MaterializeExpression expr; + do { + expr = gen.generateExpression(type); + } while (expr.getExpectedValue() == null); + return expr; + } + + public static MaterializeExpression generateConstant(Randomly r, MaterializeDataType type) { + if (Randomly.getBooleanWithSmallProbability()) { + return MaterializeConstant.createNullConstant(); + } + switch (type) { + case INT: + if (Randomly.getBooleanWithSmallProbability()) { + return MaterializeConstant.createTextConstant(String.valueOf(r.getInteger())); + } else { + return MaterializeConstant.createIntConstant(r.getInteger()); + } + case BOOLEAN: + if (Randomly.getBooleanWithSmallProbability() && !MaterializeProvider.generateOnlyKnown) { + return MaterializeConstant + .createTextConstant(Randomly.fromOptions("TR", "TRUE", "FA", "FALSE", "0", "1", "ON", "off")); + } else { + return MaterializeConstant.createBooleanConstant(Randomly.getBoolean()); + } + case TEXT: + return MaterializeConstant.createTextConstant(r.getString()); + case DECIMAL: + return MaterializeConstant.createDecimalConstant(r.getRandomBigDecimal()); + case FLOAT: + return MaterializeConstant.createFloatConstant((float) r.getDouble()); + case REAL: + return MaterializeConstant.createDoubleConstant(r.getDouble()); + case BIT: + return MaterializeConstant.createBitConstant(r.getInteger()); + default: + throw new AssertionError(type); + } + } + + public static MaterializeExpression generateExpression(MaterializeGlobalState globalState, + List columns, MaterializeDataType type) { + return new MaterializeExpressionGenerator(globalState).setColumns(columns).generateExpression(0, type); + } + + public static MaterializeExpression generateExpression(MaterializeGlobalState globalState, + List columns) { + return new MaterializeExpressionGenerator(globalState).setColumns(columns).generateExpression(0); + + } + + public List generateExpressions(int nr) { + List expressions = new ArrayList<>(); + for (int i = 0; i < nr; i++) { + expressions.add(generateExpression(0)); + } + return expressions; + } + + public MaterializeExpression generateExpression(MaterializeDataType dataType) { + return generateExpression(0, dataType); + } + + public MaterializeExpressionGenerator setGlobalState(MaterializeGlobalState globalState) { + this.globalState = globalState; + return this; + } + + public MaterializeExpression generateHavingClause() { + this.allowAggregateFunctions = true; + MaterializeExpression expression = generateExpression(MaterializeDataType.BOOLEAN); + this.allowAggregateFunctions = false; + return expression; + } + + public MaterializeExpression generateAggregate() { + return getAggregate(MaterializeDataType.getRandomType()); + } + + private MaterializeExpression getAggregate(MaterializeDataType dataType) { + List aggregates = MaterializeAggregateFunction.getAggregates(dataType); + MaterializeAggregateFunction agg = Randomly.fromList(aggregates); + return generateArgsForAggregate(dataType, agg); + } + + public MaterializeAggregate generateArgsForAggregate(MaterializeDataType dataType, + MaterializeAggregateFunction agg) { + List types = agg.getTypes(dataType); + List args = new ArrayList<>(); + for (MaterializeDataType argType : types) { + args.add(generateExpression(argType)); + } + return new MaterializeAggregate(args, agg); + } + + public MaterializeExpressionGenerator allowAggregates(boolean value) { + allowAggregateFunctions = value; + return this; + } + + @Override + public MaterializeExpression generatePredicate() { + return generateExpression(MaterializeDataType.BOOLEAN); + } + + @Override + public MaterializeExpression negatePredicate(MaterializeExpression predicate) { + return new MaterializePrefixOperation(predicate, MaterializePrefixOperation.PrefixOperator.NOT); + } + + @Override + public MaterializeExpression isNull(MaterializeExpression expr) { + return new MaterializePostfixOperation(expr, PostfixOperator.IS_NULL); + } + + @Override + public MaterializeExpressionGenerator setTablesAndColumns( + AbstractTables tables) { + this.columns = tables.getColumns(); + this.tables = tables.getTables(); + + return this; + } + + @Override + public MaterializeExpression generateBooleanExpression() { + return generateExpression(MaterializeDataType.BOOLEAN); + } + + @Override + public MaterializeSelect generateSelect() { + return new MaterializeSelect(); + } + + @Override + public List getRandomJoinClauses() { + List joinStatements = new ArrayList<>(); + MaterializeExpressionGenerator gen = new MaterializeExpressionGenerator(globalState).setColumns(columns); + for (int i = 1; i < tables.size(); i++) { + MaterializeExpression joinClause = gen.generateExpression(MaterializeDataType.BOOLEAN); + MaterializeTable table = Randomly.fromList(tables); + tables.remove(table); + MaterializeJoinType options = MaterializeJoinType.getRandom(); + MaterializeJoin j = new MaterializeJoin(new MaterializeFromTable(table, Randomly.getBoolean()), joinClause, + options); + joinStatements.add(j); + } + // JOIN subqueries + for (int i = 0; i < Randomly.smallNumber(); i++) { + MaterializeTables subqueryTables = globalState.getSchema().getRandomTableNonEmptyTables(); + MaterializeSubquery subquery = MaterializeTLPBase.createSubquery(globalState, String.format("sub%d", i), + subqueryTables); + MaterializeExpression joinClause = gen.generateExpression(MaterializeDataType.BOOLEAN); + MaterializeJoinType options = MaterializeJoinType.getRandom(); + MaterializeJoin j = new MaterializeJoin(subquery, joinClause, options); + joinStatements.add(j); + } + + return joinStatements; + } + + @Override + public List getTableRefs() { + return tables.stream().map(t -> new MaterializeFromTable(t, Randomly.getBoolean())) + .collect(Collectors.toList()); + } + + @Override + public String generateOptimizedQueryString(MaterializeSelect select, MaterializeExpression whereCondition, + boolean shouldUseAggregate) { + if (shouldUseAggregate) { + MaterializeAggregate aggr = new MaterializeAggregate( + List.of(new MaterializeColumnValue(MaterializeColumn.createDummy("*"), null)), + MaterializeAggregateFunction.COUNT); + select.setFetchColumns(List.of(aggr)); + } else { + MaterializeColumnValue allColumns = new MaterializeColumnValue(Randomly.fromList(columns), null); + select.setFetchColumns(List.of(allColumns)); + if (Randomly.getBooleanWithSmallProbability()) { + select.setOrderByClauses(generateOrderBys()); + } + select.setSelectType(SelectType.ALL); + } + select.setWhereClause(whereCondition); + + return select.asString(); + } + + @Override + public String generateUnoptimizedQueryString(MaterializeSelect select, MaterializeExpression whereCondition) { + MaterializeCastOperation isTrue = new MaterializeCastOperation(whereCondition, + MaterializeCompoundDataType.create(MaterializeDataType.INT)); + MaterializePostfixText asText = new MaterializePostfixText(isTrue, " as count", null, MaterializeDataType.INT); + select.setFetchColumns(List.of(asText)); + select.setSelectType(SelectType.ALL); + select.setWhereClause(null); + + return "SELECT SUM(count) FROM (" + select.asString() + ") as res"; + } + + @Override + public List generateFetchColumns(boolean shouldCreateDummy) { + if (shouldCreateDummy) { + return List.of(new MaterializeColumnValue(MaterializeColumn.createDummy("*"), null)); + } + List fetchColumns = new ArrayList<>(); + List targetColumns = Randomly.nonEmptySubset(columns); + for (MaterializeColumn c : targetColumns) { + fetchColumns.add(new MaterializeColumnValue(c, null)); + } + return fetchColumns; + } +} diff --git a/src/sqlancer/materialize/gen/MaterializeIndexGenerator.java b/src/sqlancer/materialize/gen/MaterializeIndexGenerator.java new file mode 100644 index 000000000..9d7a91b50 --- /dev/null +++ b/src/sqlancer/materialize/gen/MaterializeIndexGenerator.java @@ -0,0 +1,80 @@ +package sqlancer.materialize.gen; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.materialize.MaterializeGlobalState; +import sqlancer.materialize.MaterializeSchema.MaterializeTable; + +public final class MaterializeIndexGenerator { + + private MaterializeIndexGenerator() { + } + + public enum IndexType { + BTREE, HASH, GIST, GIN + } + + public static SQLQueryAdapter generate(MaterializeGlobalState globalState) { + ExpectedErrors errors = new ExpectedErrors(); + StringBuilder sb = new StringBuilder(); + sb.append("CREATE"); + sb.append(" INDEX "); + MaterializeTable randomTable = globalState.getSchema().getRandomTable(t -> !t.isView()); // TODO: materialized + // views + sb.append(MaterializeCommon.getFreeIndexName(globalState.getSchema())); + sb.append(" ON "); + sb.append(randomTable.getName()); + IndexType method; + method = IndexType.BTREE; + + sb.append("("); + if (method == IndexType.HASH) { + sb.append(randomTable.getRandomColumn().getName()); + } else { + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + if (i != 0) { + sb.append(", "); + } + sb.append(randomTable.getRandomColumn().getName()); + if (Randomly.getBoolean()) { + sb.append(" "); + sb.append(Randomly.fromOptions("ASC", "DESC")); + } + if (Randomly.getBooleanWithRatherLowProbability()) { + sb.append(" NULLS "); + sb.append(Randomly.fromOptions("FIRST", "LAST")); + } + } + } + + sb.append(")"); + errors.add("already contains data"); // CONCURRENT INDEX failed + errors.add("You might need to add explicit type casts"); + errors.add(" collations are not supported"); + errors.add("because it has pending trigger events"); + errors.add("could not determine which collation to use for index expression"); + errors.add("could not determine which collation to use for string comparison"); + errors.add("is duplicated"); + errors.add("access method \"gin\" does not support unique indexes"); + errors.add("access method \"hash\" does not support unique indexes"); + errors.add("already exists"); + errors.add("could not create unique index"); + errors.add("has no default operator class"); + errors.add("does not support"); + errors.add("does not support casting"); + errors.add("unsupported UNIQUE constraint with partition key definition"); + errors.add("insufficient columns in UNIQUE constraint definition"); + errors.add("invalid input syntax for"); + errors.add("must be type "); + errors.add("integer out of range"); + errors.add("division by zero"); + errors.add("out of range"); + errors.add("functions in index predicate must be marked IMMUTABLE"); + errors.add("functions in index expression must be marked IMMUTABLE"); + errors.add("result of range difference would not be contiguous"); + errors.add("which is part of the partition key"); + MaterializeCommon.addCommonExpressionErrors(errors); + return new SQLQueryAdapter(sb.toString(), errors); + } +} diff --git a/src/sqlancer/materialize/gen/MaterializeInsertGenerator.java b/src/sqlancer/materialize/gen/MaterializeInsertGenerator.java new file mode 100644 index 000000000..7a5374b95 --- /dev/null +++ b/src/sqlancer/materialize/gen/MaterializeInsertGenerator.java @@ -0,0 +1,106 @@ +package sqlancer.materialize.gen; + +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.materialize.MaterializeGlobalState; +import sqlancer.materialize.MaterializeSchema.MaterializeColumn; +import sqlancer.materialize.MaterializeSchema.MaterializeTable; +import sqlancer.materialize.MaterializeVisitor; +import sqlancer.materialize.ast.MaterializeExpression; + +public final class MaterializeInsertGenerator { + + private MaterializeInsertGenerator() { + } + + public static SQLQueryAdapter insert(MaterializeGlobalState globalState) { + MaterializeTable table = globalState.getSchema().getRandomTable(t -> t.isInsertable()); + ExpectedErrors errors = new ExpectedErrors(); + errors.add("cannot insert into column"); + MaterializeCommon.addCommonExpressionErrors(errors); + MaterializeCommon.addCommonInsertUpdateErrors(errors); + MaterializeCommon.addCommonExpressionErrors(errors); + errors.add("multiple assignments to same column"); + errors.add("violates foreign key constraint"); + errors.add("value too long for type character"); + errors.add("conflicting key value violates exclusion constraint"); + errors.add("violates not-null constraint"); + errors.add("current transaction is aborted"); + errors.add("bit string too long"); + errors.add("new row violates check option for view"); + errors.add("reached maximum value of sequence"); + errors.add("but expression is of type"); + StringBuilder sb = new StringBuilder(); + sb.append("INSERT INTO "); + sb.append(table.getName()); + List columns = table.getRandomNonEmptyColumnSubset(); + sb.append("("); + sb.append(columns.stream().map(c -> c.getName()).collect(Collectors.joining(", "))); + sb.append(")"); + sb.append(" VALUES"); + + if (globalState.getDbmsSpecificOptions().allowBulkInsert && Randomly.getBooleanWithSmallProbability()) { + StringBuilder sbRowValue = new StringBuilder(); + sbRowValue.append("("); + for (int i = 0; i < columns.size(); i++) { + if (i != 0) { + sbRowValue.append(", "); + } + sbRowValue.append(MaterializeVisitor.asString(MaterializeExpressionGenerator + .generateConstant(globalState.getRandomly(), columns.get(i).getType()))); + } + sbRowValue.append(")"); + + int n = (int) Randomly.getNotCachedInteger(100, 1000); + for (int i = 0; i < n; i++) { + if (i != 0) { + sb.append(", "); + } + sb.append(sbRowValue); + } + } else { + int n = Randomly.smallNumber() + 1; + for (int i = 0; i < n; i++) { + if (i != 0) { + sb.append(", "); + } + insertRow(globalState, sb, columns); + } + } + errors.add("duplicate key value violates unique constraint"); + errors.add("identity column defined as GENERATED ALWAYS"); + errors.add("out of range"); + errors.add("violates check constraint"); + errors.add("no partition of relation"); + errors.add("invalid input syntax"); + errors.add("division by zero"); + errors.add("violates foreign key constraint"); + errors.add("data type unknown"); + return new SQLQueryAdapter(sb.toString(), errors); + } + + private static void insertRow(MaterializeGlobalState globalState, StringBuilder sb, + List columns) { + sb.append("("); + for (int i = 0; i < columns.size(); i++) { + if (i != 0) { + sb.append(", "); + } + MaterializeExpression generateConstant; + if (Randomly.getBoolean()) { + generateConstant = MaterializeExpressionGenerator.generateConstant(globalState.getRandomly(), + columns.get(i).getType()); + } else { + generateConstant = new MaterializeExpressionGenerator(globalState) + .generateExpression(columns.get(i).getType()); + } + sb.append(MaterializeVisitor.asString(generateConstant)); + } + sb.append(")"); + } + +} diff --git a/src/sqlancer/materialize/gen/MaterializeRandomQueryGenerator.java b/src/sqlancer/materialize/gen/MaterializeRandomQueryGenerator.java new file mode 100644 index 000000000..05d908e05 --- /dev/null +++ b/src/sqlancer/materialize/gen/MaterializeRandomQueryGenerator.java @@ -0,0 +1,64 @@ +package sqlancer.materialize.gen; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.materialize.MaterializeGlobalState; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; +import sqlancer.materialize.MaterializeSchema.MaterializeTables; +import sqlancer.materialize.ast.MaterializeConstant; +import sqlancer.materialize.ast.MaterializeExpression; +import sqlancer.materialize.ast.MaterializeSelect; +import sqlancer.materialize.ast.MaterializeSelect.ForClause; +import sqlancer.materialize.ast.MaterializeSelect.MaterializeFromTable; +import sqlancer.materialize.ast.MaterializeSelect.SelectType; + +public final class MaterializeRandomQueryGenerator { + + private MaterializeRandomQueryGenerator() { + } + + public static MaterializeSelect createRandomQuery(int nrColumns, MaterializeGlobalState globalState) { + List columns = new ArrayList<>(); + MaterializeTables tables = globalState.getSchema().getRandomTableNonEmptyTables(); + MaterializeExpressionGenerator gen = new MaterializeExpressionGenerator(globalState) + .setColumns(tables.getColumns()); + for (int i = 0; i < nrColumns; i++) { + columns.add(gen.generateExpression(0)); + } + MaterializeSelect select = new MaterializeSelect(); + select.setSelectType(SelectType.getRandom()); + if (select.getSelectOption() == SelectType.DISTINCT && Randomly.getBoolean()) { + select.setDistinctOnClause(gen.generateExpression(0)); + } + select.setFromList(tables.getTables().stream().map(t -> new MaterializeFromTable(t, Randomly.getBoolean())) + .collect(Collectors.toList())); + select.setFetchColumns(columns); + if (Randomly.getBoolean()) { + select.setWhereClause(gen.generateExpression(0, MaterializeDataType.BOOLEAN)); + } + if (Randomly.getBooleanWithRatherLowProbability()) { + select.setGroupByExpressions(gen.generateExpressions(Randomly.smallNumber() + 1)); + if (Randomly.getBoolean()) { + select.setHavingClause(gen.generateHavingClause()); + } + } + if (Randomly.getBooleanWithRatherLowProbability()) { + select.setOrderByClauses(gen.generateOrderBys()); + } + if (Randomly.getBoolean()) { + select.setLimitClause(MaterializeConstant.createIntConstant(Randomly.getPositiveOrZeroNonCachedInteger())); + if (Randomly.getBoolean()) { + select.setOffsetClause( + MaterializeConstant.createIntConstant(Randomly.getPositiveOrZeroNonCachedInteger())); + } + } + if (Randomly.getBooleanWithRatherLowProbability()) { + select.setForClause(ForClause.getRandom()); + } + return select; + } + +} diff --git a/src/sqlancer/materialize/gen/MaterializeTableGenerator.java b/src/sqlancer/materialize/gen/MaterializeTableGenerator.java new file mode 100644 index 000000000..c6772db47 --- /dev/null +++ b/src/sqlancer/materialize/gen/MaterializeTableGenerator.java @@ -0,0 +1,128 @@ +package sqlancer.materialize.gen; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.DBMSCommon; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.materialize.MaterializeGlobalState; +import sqlancer.materialize.MaterializeSchema; +import sqlancer.materialize.MaterializeSchema.MaterializeColumn; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; +import sqlancer.materialize.MaterializeSchema.MaterializeTable; +import sqlancer.materialize.MaterializeVisitor; + +public class MaterializeTableGenerator { + + private final String tableName; + private final StringBuilder sb = new StringBuilder(); + @SuppressWarnings("unused") + private boolean isTemporaryTable; + @SuppressWarnings("unused") + private final MaterializeSchema newSchema; + private final List columnsToBeAdded = new ArrayList<>(); + protected final ExpectedErrors errors = new ExpectedErrors(); + private final MaterializeTable table; + private final boolean generateOnlyKnown; + private final MaterializeGlobalState globalState; + + public MaterializeTableGenerator(String tableName, MaterializeSchema newSchema, boolean generateOnlyKnown, + MaterializeGlobalState globalState) { + this.tableName = tableName; + this.newSchema = newSchema; + this.generateOnlyKnown = generateOnlyKnown; + this.globalState = globalState; + table = new MaterializeTable(tableName, columnsToBeAdded, null, null, null, false, false); + errors.add("invalid input syntax for"); + errors.add("is not unique"); + errors.add("integer out of range"); + errors.add("division by zero"); + errors.add("cannot create partitioned table as inheritance child"); + errors.add("does not support casting"); + errors.add("ERROR: functions in index expression must be marked IMMUTABLE"); + errors.add("functions in partition key expression must be marked IMMUTABLE"); + errors.add("functions in index predicate must be marked IMMUTABLE"); + errors.add("has no default operator class for access method"); + errors.add("does not exist for access method"); + errors.add("does not accept data type"); + errors.add("but default expression is of type text"); + errors.add("has pseudo-type unknown"); + errors.add("no collation was derived for partition key column"); + errors.add("inherits from generated column but specifies identity"); + errors.add("inherits from generated column but specifies default"); + errors.add("already exists"); + MaterializeCommon.addCommonExpressionErrors(errors); + MaterializeCommon.addCommonTableErrors(errors); + } + + public static SQLQueryAdapter generate(String tableName, MaterializeSchema newSchema, boolean generateOnlyKnown, + MaterializeGlobalState globalState) { + return new MaterializeTableGenerator(tableName, newSchema, generateOnlyKnown, globalState).generate(); + } + + protected SQLQueryAdapter generate() { + sb.append("CREATE"); + sb.append(" TABLE"); + if (Randomly.getBoolean()) { + sb.append(" IF NOT EXISTS"); + } + sb.append(" "); + sb.append(tableName); + createStandard(); + return new SQLQueryAdapter(sb.toString(), errors, true); + } + + private void createStandard() throws AssertionError { + sb.append("("); + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + if (i != 0) { + sb.append(", "); + } + String name = DBMSCommon.createColumnName(i); + createColumn(name); + } + sb.append(")"); + } + + private void createColumn(String name) throws AssertionError { + sb.append(name); + sb.append(" "); + MaterializeDataType type = MaterializeDataType.getRandomType(); + MaterializeCommon.appendDataType(type, sb, true, generateOnlyKnown, globalState.getCollates()); + MaterializeColumn c = new MaterializeColumn(name, type); + c.setTable(table); + columnsToBeAdded.add(c); + sb.append(" "); + if (Randomly.getBoolean()) { + createColumnConstraint(type); + } + } + + private enum ColumnConstraint { + DEFAULT + }; + + private void createColumnConstraint(MaterializeDataType type) { + List constraintSubset = Randomly.nonEmptySubset(ColumnConstraint.values()); + for (ColumnConstraint c : constraintSubset) { + sb.append(" "); + switch (c) { + case DEFAULT: + sb.append("DEFAULT"); + sb.append(" ("); + sb.append(MaterializeVisitor + .asString(MaterializeExpressionGenerator.generateExpression(globalState, type))); + sb.append(")"); + // CREATE TEMPORARY TABLE t1(c0 smallint DEFAULT ('566963878')); + errors.add("out of range"); + errors.add("is a generated column"); + break; + default: + throw new AssertionError(sb); + } + } + } + +} diff --git a/src/sqlancer/materialize/gen/MaterializeUpdateGenerator.java b/src/sqlancer/materialize/gen/MaterializeUpdateGenerator.java new file mode 100644 index 000000000..1b7e69208 --- /dev/null +++ b/src/sqlancer/materialize/gen/MaterializeUpdateGenerator.java @@ -0,0 +1,78 @@ +package sqlancer.materialize.gen; + +import java.util.Arrays; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.gen.AbstractUpdateGenerator; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.materialize.MaterializeGlobalState; +import sqlancer.materialize.MaterializeSchema.MaterializeColumn; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; +import sqlancer.materialize.MaterializeSchema.MaterializeTable; +import sqlancer.materialize.MaterializeVisitor; +import sqlancer.materialize.ast.MaterializeExpression; + +public final class MaterializeUpdateGenerator extends AbstractUpdateGenerator { + + private final MaterializeGlobalState globalState; + private MaterializeTable randomTable; + + private MaterializeUpdateGenerator(MaterializeGlobalState globalState) { + this.globalState = globalState; + errors.addAll(Arrays.asList("conflicting key value violates exclusion constraint", + "reached maximum value of sequence", "violates foreign key constraint", "violates not-null constraint", + "violates unique constraint", "out of range", "does not support casting", "must be type boolean", + "is not unique", " bit string too long", "can only be updated to DEFAULT", "division by zero", + "You might need to add explicit type casts.", "invalid regular expression", + "View columns that are not columns of their base relation are not updatable")); + } + + public static SQLQueryAdapter create(MaterializeGlobalState globalState) { + return new MaterializeUpdateGenerator(globalState).generate(); + } + + private SQLQueryAdapter generate() { + randomTable = globalState.getSchema().getRandomTable(t -> t.isInsertable()); + List columns = randomTable.getRandomNonEmptyColumnSubset(); + sb.append("UPDATE "); + sb.append(randomTable.getName()); + sb.append(" SET "); + errors.add("multiple assignments to same column"); // view whose columns refer to a column in the referenced + // table multiple times + errors.add("new row violates check option for view"); + MaterializeCommon.addCommonInsertUpdateErrors(errors); + updateColumns(columns); + errors.add("invalid input syntax for "); + errors.add("operator does not exist: text = boolean"); + errors.add("violates check constraint"); + errors.add("could not determine which collation to use for string comparison"); + errors.add("but expression is of type"); + MaterializeCommon.addCommonExpressionErrors(errors); + if (!Randomly.getBooleanWithSmallProbability()) { + sb.append(" WHERE "); + MaterializeExpression where = MaterializeExpressionGenerator.generateExpression(globalState, + randomTable.getColumns(), MaterializeDataType.BOOLEAN); + sb.append(MaterializeVisitor.asString(where)); + } + + return new SQLQueryAdapter(sb.toString(), errors, true); + } + + @Override + protected void updateValue(MaterializeColumn column) { + if (!Randomly.getBoolean()) { + MaterializeExpression constant = MaterializeExpressionGenerator.generateConstant(globalState.getRandomly(), + column.getType()); + sb.append(MaterializeVisitor.asString(constant)); + } else { + sb.append("("); + MaterializeExpression expr = MaterializeExpressionGenerator.generateExpression(globalState, + randomTable.getColumns(), column.getType()); + // caused by casts + sb.append(MaterializeVisitor.asString(expr)); + sb.append(")"); + } + } + +} diff --git a/src/sqlancer/materialize/gen/MaterializeViewGenerator.java b/src/sqlancer/materialize/gen/MaterializeViewGenerator.java new file mode 100644 index 000000000..e3cb8ff29 --- /dev/null +++ b/src/sqlancer/materialize/gen/MaterializeViewGenerator.java @@ -0,0 +1,65 @@ +package sqlancer.materialize.gen; + +import sqlancer.Randomly; +import sqlancer.common.DBMSCommon; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.materialize.MaterializeGlobalState; +import sqlancer.materialize.MaterializeVisitor; +import sqlancer.materialize.ast.MaterializeSelect; + +public final class MaterializeViewGenerator { + + private MaterializeViewGenerator() { + } + + public static SQLQueryAdapter create(MaterializeGlobalState globalState) { + ExpectedErrors errors = new ExpectedErrors(); + StringBuilder sb = new StringBuilder("CREATE"); + @SuppressWarnings("unused") + boolean materialized; + @SuppressWarnings("unused") + boolean recursive = false; + if (Randomly.getBoolean()) { + sb.append(" MATERIALIZED"); + materialized = true; + } else { + if (Randomly.getBoolean()) { + sb.append(" OR REPLACE"); + } + materialized = false; + } + sb.append(" VIEW "); + String name = globalState.getSchema().getFreeViewName(); + sb.append(name); + sb.append("("); + int nrColumns = Randomly.smallNumber() + 1; + for (int i = 0; i < nrColumns; i++) { + if (i != 0) { + sb.append(", "); + } + sb.append(DBMSCommon.createColumnName(i)); + } + sb.append(")"); + sb.append(" AS ("); + MaterializeSelect select = MaterializeRandomQueryGenerator.createRandomQuery(nrColumns, globalState); + sb.append(MaterializeVisitor.asString(select)); + sb.append(")"); + MaterializeCommon.addGroupingErrors(errors); + errors.add("already exists"); + errors.add("cannot drop columns from view"); + errors.add("non-integer constant in ORDER BY"); // TODO + errors.add("for SELECT DISTINCT, ORDER BY expressions must appear in select list"); // TODO + errors.add("cannot change data type of view column"); + errors.add("specified more than once"); // TODO + errors.add("materialized views must not use temporary tables or views"); + errors.add("does not have the form non-recursive-term UNION [ALL] recursive-term"); + errors.add("is not a view"); + errors.add("non-integer constant in DISTINCT ON"); + errors.add("unable to parse column reference in DISTINCT ON clause"); + errors.add("SELECT DISTINCT ON expressions must match initial ORDER BY expressions"); + MaterializeCommon.addCommonExpressionErrors(errors); + return new SQLQueryAdapter(sb.toString(), errors, true); + } + +} diff --git a/src/sqlancer/materialize/oracle/MaterializePivotedQuerySynthesisOracle.java b/src/sqlancer/materialize/oracle/MaterializePivotedQuerySynthesisOracle.java new file mode 100644 index 000000000..50dff2b49 --- /dev/null +++ b/src/sqlancer/materialize/oracle/MaterializePivotedQuerySynthesisOracle.java @@ -0,0 +1,151 @@ +package sqlancer.materialize.oracle; + +import java.sql.SQLException; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.common.oracle.PivotedQuerySynthesisBase; +import sqlancer.common.query.Query; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.materialize.MaterializeGlobalState; +import sqlancer.materialize.MaterializeSchema.MaterializeColumn; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; +import sqlancer.materialize.MaterializeSchema.MaterializeRowValue; +import sqlancer.materialize.MaterializeSchema.MaterializeTables; +import sqlancer.materialize.MaterializeVisitor; +import sqlancer.materialize.ast.MaterializeColumnValue; +import sqlancer.materialize.ast.MaterializeConstant; +import sqlancer.materialize.ast.MaterializeExpression; +import sqlancer.materialize.ast.MaterializePostfixOperation; +import sqlancer.materialize.ast.MaterializePostfixOperation.PostfixOperator; +import sqlancer.materialize.ast.MaterializeSelect; +import sqlancer.materialize.ast.MaterializeSelect.MaterializeFromTable; +import sqlancer.materialize.gen.MaterializeCommon; +import sqlancer.materialize.gen.MaterializeExpressionGenerator; + +public class MaterializePivotedQuerySynthesisOracle extends + PivotedQuerySynthesisBase { + + private List fetchColumns; + + public MaterializePivotedQuerySynthesisOracle(MaterializeGlobalState globalState) throws SQLException { + super(globalState); + MaterializeCommon.addCommonExpressionErrors(errors); + MaterializeCommon.addCommonFetchErrors(errors); + } + + @Override + public SQLQueryAdapter getRectifiedQuery() throws SQLException { + MaterializeTables randomFromTables = globalState.getSchema().getRandomTableNonEmptyTables(); + + MaterializeSelect selectStatement = new MaterializeSelect(); + selectStatement.setSelectType(Randomly.fromOptions(MaterializeSelect.SelectType.values())); + List columns = randomFromTables.getColumns(); + pivotRow = randomFromTables.getRandomRowValue(globalState.getConnection()); + + fetchColumns = columns; + selectStatement.setFromList(randomFromTables.getTables().stream().map(t -> new MaterializeFromTable(t, false)) + .collect(Collectors.toList())); + selectStatement.setFetchColumns(fetchColumns.stream() + .map(c -> new MaterializeColumnValue(getFetchValueAliasedColumn(c), pivotRow.getValues().get(c))) + .collect(Collectors.toList())); + MaterializeExpression whereClause = generateRectifiedExpression(columns, pivotRow); + selectStatement.setWhereClause(whereClause); + List groupByClause = generateGroupByClause(columns, pivotRow); + selectStatement.setGroupByExpressions(groupByClause); + MaterializeExpression limitClause = generateLimit(); + selectStatement.setLimitClause(limitClause); + if (limitClause != null) { + MaterializeExpression offsetClause = generateOffset(); + selectStatement.setOffsetClause(offsetClause); + } + List orderBy = new MaterializeExpressionGenerator(globalState).setColumns(columns) + .generateOrderBys(); + selectStatement.setOrderByClauses(orderBy); + return new SQLQueryAdapter(MaterializeVisitor.asString(selectStatement)); + } + + /* + * Prevent name collisions by aliasing the column. + */ + private MaterializeColumn getFetchValueAliasedColumn(MaterializeColumn c) { + MaterializeColumn aliasedColumn = new MaterializeColumn( + c.getName() + " AS " + c.getTable().getName() + c.getName(), c.getType()); + aliasedColumn.setTable(c.getTable()); + return aliasedColumn; + } + + private List generateGroupByClause(List columns, MaterializeRowValue rw) { + if (Randomly.getBoolean()) { + return columns.stream().map(c -> MaterializeColumnValue.create(c, rw.getValues().get(c))) + .collect(Collectors.toList()); + } else { + return Collections.emptyList(); + } + } + + private MaterializeConstant generateLimit() { + if (Randomly.getBoolean()) { + return MaterializeConstant.createIntConstant(Integer.MAX_VALUE); + } else { + return null; + } + } + + private MaterializeExpression generateOffset() { + if (Randomly.getBoolean()) { + return MaterializeConstant.createIntConstant(0); + } else { + return null; + } + } + + private MaterializeExpression generateRectifiedExpression(List columns, MaterializeRowValue rw) { + MaterializeExpression expr = new MaterializeExpressionGenerator(globalState).setColumns(columns).setRowValue(rw) + .generateExpressionWithExpectedResult(MaterializeDataType.BOOLEAN); + MaterializeExpression result; + if (expr.getExpectedValue().isNull()) { + result = MaterializePostfixOperation.create(expr, PostfixOperator.IS_NULL); + } else { + result = MaterializePostfixOperation.create(expr, + expr.getExpectedValue().cast(MaterializeDataType.BOOLEAN).asBoolean() ? PostfixOperator.IS_TRUE + : PostfixOperator.IS_FALSE); + } + rectifiedPredicates.add(result); + return result; + } + + @Override + protected Query getContainmentCheckQuery(Query query) throws SQLException { + StringBuilder sb = new StringBuilder(); + sb.append("SELECT * FROM ("); // ANOTHER SELECT TO USE ORDER BY without restrictions + sb.append(query.getUnterminatedQueryString()); + sb.append(") as result WHERE "); + int i = 0; + for (MaterializeColumn c : fetchColumns) { + if (i++ != 0) { + sb.append(" AND "); + } + sb.append("result."); + sb.append(c.getTable().getName()); + sb.append(c.getName()); + if (pivotRow.getValues().get(c).isNull()) { + sb.append(" IS NULL"); + } else { + sb.append(" = "); + sb.append(pivotRow.getValues().get(c).getTextRepresentation()); + } + } + String resultingQueryString = sb.toString(); + return new SQLQueryAdapter(resultingQueryString, errors); + } + + @Override + protected String getExpectedValues(MaterializeExpression expr) { + return MaterializeVisitor.asExpectedValues(expr); + } + +} diff --git a/src/sqlancer/materialize/oracle/tlp/MaterializeTLPAggregateOracle.java b/src/sqlancer/materialize/oracle/tlp/MaterializeTLPAggregateOracle.java new file mode 100644 index 000000000..6bbd7d795 --- /dev/null +++ b/src/sqlancer/materialize/oracle/tlp/MaterializeTLPAggregateOracle.java @@ -0,0 +1,187 @@ +package sqlancer.materialize.oracle.tlp; + +import java.io.IOException; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.postgresql.util.PSQLException; + +import sqlancer.ComparatorHelper; +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.query.SQLancerResultSet; +import sqlancer.materialize.MaterializeGlobalState; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; +import sqlancer.materialize.MaterializeVisitor; +import sqlancer.materialize.ast.MaterializeAggregate; +import sqlancer.materialize.ast.MaterializeAggregate.MaterializeAggregateFunction; +import sqlancer.materialize.ast.MaterializeAlias; +import sqlancer.materialize.ast.MaterializeExpression; +import sqlancer.materialize.ast.MaterializeJoin; +import sqlancer.materialize.ast.MaterializePostfixOperation; +import sqlancer.materialize.ast.MaterializePostfixOperation.PostfixOperator; +import sqlancer.materialize.ast.MaterializePrefixOperation; +import sqlancer.materialize.ast.MaterializePrefixOperation.PrefixOperator; +import sqlancer.materialize.ast.MaterializeSelect; +import sqlancer.materialize.gen.MaterializeCommon; + +public class MaterializeTLPAggregateOracle extends MaterializeTLPBase implements TestOracle { + private String generatedQueryString; + + private String firstResult; + private String secondResult; + private String originalQuery; + private String metamorphicQuery; + + public MaterializeTLPAggregateOracle(MaterializeGlobalState state) { + super(state); + MaterializeCommon.addGroupingErrors(errors); + } + + @Override + public void check() throws SQLException { + super.check(); + aggregateCheck(); + } + + protected void aggregateCheck() throws SQLException { + MaterializeAggregateFunction aggregateFunction = Randomly.fromOptions(MaterializeAggregateFunction.MAX, + MaterializeAggregateFunction.MIN, MaterializeAggregateFunction.SUM, + MaterializeAggregateFunction.BIT_AND, MaterializeAggregateFunction.BIT_OR, + MaterializeAggregateFunction.BOOL_AND, MaterializeAggregateFunction.BOOL_OR, + MaterializeAggregateFunction.COUNT); + MaterializeAggregate aggregate = gen.generateArgsForAggregate(aggregateFunction.getRandomReturnType(), + aggregateFunction); + List fetchColumns = new ArrayList<>(); + fetchColumns.add(aggregate); + while (Randomly.getBooleanWithRatherLowProbability()) { + fetchColumns.add(gen.generateAggregate()); + } + select.setFetchColumns(Arrays.asList(aggregate)); + if (Randomly.getBooleanWithRatherLowProbability()) { + select.setOrderByClauses(gen.generateOrderBys()); + } + originalQuery = MaterializeVisitor.asString(select); + generatedQueryString = originalQuery; + firstResult = getAggregateResult(originalQuery); + metamorphicQuery = createMetamorphicUnionQuery(select, aggregate, select.getFromList()); + secondResult = getAggregateResult(metamorphicQuery); + + String queryFormatString = "-- %s;\n-- result: %s"; + String firstQueryString = String.format(queryFormatString, originalQuery, firstResult); + String secondQueryString = String.format(queryFormatString, metamorphicQuery, secondResult); + state.getState().getLocalState().log(String.format("%s\n%s", firstQueryString, secondQueryString)); + if (firstResult == null && secondResult != null || firstResult != null && secondResult == null + || firstResult != null && !firstResult.contentEquals(secondResult) + && !ComparatorHelper.isEqualDouble(firstResult, secondResult)) { + if (secondResult != null && secondResult.contains("Inf")) { + throw new IgnoreMeException(); // FIXME: average computation + } + String assertionMessage = String.format("the results mismatch!\n%s\n%s", firstQueryString, + secondQueryString); + throw new AssertionError(assertionMessage); + } + } + + private String createMetamorphicUnionQuery(MaterializeSelect select, MaterializeAggregate aggregate, + List from) { + String metamorphicQuery; + MaterializeExpression whereClause = gen.generateExpression(MaterializeDataType.BOOLEAN); + MaterializeExpression negatedClause = new MaterializePrefixOperation(whereClause, PrefixOperator.NOT); + MaterializeExpression notNullClause = new MaterializePostfixOperation(whereClause, PostfixOperator.IS_NULL); + List mappedAggregate = mapped(aggregate); + MaterializeSelect leftSelect = getSelect(mappedAggregate, from, whereClause, select.getJoinClauses()); + MaterializeSelect middleSelect = getSelect(mappedAggregate, from, negatedClause, select.getJoinClauses()); + MaterializeSelect rightSelect = getSelect(mappedAggregate, from, notNullClause, select.getJoinClauses()); + metamorphicQuery = "SELECT " + getOuterAggregateFunction(aggregate) + " FROM ("; + metamorphicQuery += MaterializeVisitor.asString(leftSelect) + " UNION ALL " + + MaterializeVisitor.asString(middleSelect) + " UNION ALL " + MaterializeVisitor.asString(rightSelect); + metamorphicQuery += ") as asdf"; + return metamorphicQuery; + } + + private String getAggregateResult(String queryString) throws SQLException { + // log TLP Aggregate SELECT queries on the current log file + if (state.getOptions().logEachSelect()) { + // TODO: refactor me + state.getLogger().writeCurrent(queryString); + try { + state.getLogger().getCurrentFileWriter().flush(); + } catch (IOException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + } + String resultString; + SQLQueryAdapter q = new SQLQueryAdapter(queryString, errors); + try (SQLancerResultSet result = q.executeAndGet(state)) { + if (result == null) { + throw new IgnoreMeException(); + } + if (!result.next()) { + resultString = null; + } else { + resultString = result.getString(1); + } + } catch (PSQLException e) { + throw new AssertionError(queryString, e); + } + return resultString; + } + + private List mapped(MaterializeAggregate aggregate) { + switch (aggregate.getFunction()) { + case SUM: + case COUNT: + case BIT_AND: + case BIT_OR: + case BOOL_AND: + case BOOL_OR: + case MAX: + case MIN: + return aliasArgs(Arrays.asList(aggregate)); + default: + throw new AssertionError(aggregate.getFunction()); + } + } + + private List aliasArgs(List originalAggregateArgs) { + List args = new ArrayList<>(); + int i = 0; + for (MaterializeExpression expr : originalAggregateArgs) { + args.add(new MaterializeAlias(expr, "agg" + i++)); + } + return args; + } + + private String getOuterAggregateFunction(MaterializeAggregate aggregate) { + switch (aggregate.getFunction()) { + case COUNT: + return MaterializeAggregateFunction.SUM.toString() + "(agg0)"; + default: + return aggregate.getFunction().toString() + "(agg0)"; + } + } + + private MaterializeSelect getSelect(List aggregates, List from, + MaterializeExpression whereClause, List joinList) { + MaterializeSelect leftSelect = new MaterializeSelect(); + leftSelect.setFetchColumns(aggregates); + leftSelect.setFromList(from); + leftSelect.setWhereClause(whereClause); + leftSelect.setJoinClauses(joinList); + if (Randomly.getBooleanWithSmallProbability()) { + leftSelect.setGroupByExpressions(gen.generateExpressions(Randomly.smallNumber() + 1)); + } + return leftSelect; + } + + @Override + public String getLastQueryString() { + return generatedQueryString; + } +} diff --git a/src/sqlancer/materialize/oracle/tlp/MaterializeTLPBase.java b/src/sqlancer/materialize/oracle/tlp/MaterializeTLPBase.java new file mode 100644 index 000000000..faf136c21 --- /dev/null +++ b/src/sqlancer/materialize/oracle/tlp/MaterializeTLPBase.java @@ -0,0 +1,144 @@ +package sqlancer.materialize.oracle.tlp; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.common.gen.ExpressionGenerator; +import sqlancer.common.oracle.TernaryLogicPartitioningOracleBase; +import sqlancer.common.oracle.TestOracle; +import sqlancer.materialize.MaterializeGlobalState; +import sqlancer.materialize.MaterializeSchema; +import sqlancer.materialize.MaterializeSchema.MaterializeColumn; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; +import sqlancer.materialize.MaterializeSchema.MaterializeTable; +import sqlancer.materialize.MaterializeSchema.MaterializeTables; +import sqlancer.materialize.ast.MaterializeColumnValue; +import sqlancer.materialize.ast.MaterializeConstant; +import sqlancer.materialize.ast.MaterializeExpression; +import sqlancer.materialize.ast.MaterializeJoin; +import sqlancer.materialize.ast.MaterializeJoin.MaterializeJoinType; +import sqlancer.materialize.ast.MaterializeSelect; +import sqlancer.materialize.ast.MaterializeSelect.ForClause; +import sqlancer.materialize.ast.MaterializeSelect.MaterializeFromTable; +import sqlancer.materialize.ast.MaterializeSelect.MaterializeSubquery; +import sqlancer.materialize.gen.MaterializeCommon; +import sqlancer.materialize.gen.MaterializeExpressionGenerator; + +public class MaterializeTLPBase + extends TernaryLogicPartitioningOracleBase + implements TestOracle { + + protected MaterializeSchema s; + protected MaterializeTables targetTables; + protected MaterializeExpressionGenerator gen; + protected MaterializeSelect select; + + public MaterializeTLPBase(MaterializeGlobalState state) { + super(state); + MaterializeCommon.addCommonExpressionErrors(errors); + MaterializeCommon.addCommonFetchErrors(errors); + } + + @Override + public void check() throws SQLException { + s = state.getSchema(); + targetTables = s.getRandomTableNonEmptyTables(); + List tables = targetTables.getTables(); + List joins = getJoinStatements(state, targetTables.getColumns(), tables); + generateSelectBase(tables, joins); + } + + protected List getJoinStatements(MaterializeGlobalState globalState, + List columns, List tables) { + List joinStatements = new ArrayList<>(); + MaterializeExpressionGenerator gen = new MaterializeExpressionGenerator(globalState).setColumns(columns); + for (int i = 1; i < tables.size(); i++) { + MaterializeExpression joinClause = gen.generateExpression(MaterializeDataType.BOOLEAN); + MaterializeTable table = Randomly.fromList(tables); + tables.remove(table); + MaterializeJoinType options = MaterializeJoinType.getRandom(); + MaterializeJoin j = new MaterializeJoin(new MaterializeFromTable(table, Randomly.getBoolean()), joinClause, + options); + joinStatements.add(j); + } + // JOIN subqueries + for (int i = 0; i < Randomly.smallNumber(); i++) { + MaterializeTables subqueryTables = globalState.getSchema().getRandomTableNonEmptyTables(); + MaterializeSubquery subquery = MaterializeTLPBase.createSubquery(globalState, String.format("sub%d", i), + subqueryTables); + MaterializeExpression joinClause = gen.generateExpression(MaterializeDataType.BOOLEAN); + MaterializeJoinType options = MaterializeJoinType.getRandom(); + MaterializeJoin j = new MaterializeJoin(subquery, joinClause, options); + joinStatements.add(j); + } + + return joinStatements; + } + + protected void generateSelectBase(List tables, List joins) { + List tableList = tables.stream() + .map(t -> new MaterializeFromTable(t, Randomly.getBoolean())).collect(Collectors.toList()); + gen = new MaterializeExpressionGenerator(state).setColumns(targetTables.getColumns()); + initializeTernaryPredicateVariants(); + select = new MaterializeSelect(); + select.setFetchColumns(generateFetchColumns()); + select.setFromList(tableList); + select.setWhereClause(null); + select.setJoinClauses(joins); + if (Randomly.getBoolean()) { + select.setForClause(ForClause.getRandom()); + } + } + + List generateFetchColumns() { + if (Randomly.getBooleanWithRatherLowProbability()) { + return Arrays.asList(new MaterializeColumnValue(MaterializeColumn.createDummy("*"), null)); + } + List fetchColumns = new ArrayList<>(); + List targetColumns = Randomly.nonEmptySubset(targetTables.getColumns()); + for (MaterializeColumn c : targetColumns) { + fetchColumns.add(new MaterializeColumnValue(c, null)); + } + return fetchColumns; + } + + @Override + protected ExpressionGenerator getGen() { + return gen; + } + + public static MaterializeSubquery createSubquery(MaterializeGlobalState globalState, String name, + MaterializeTables tables) { + List columns = new ArrayList<>(); + MaterializeExpressionGenerator gen = new MaterializeExpressionGenerator(globalState) + .setColumns(tables.getColumns()); + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + columns.add(gen.generateExpression(0)); + } + MaterializeSelect select = new MaterializeSelect(); + select.setFromList(tables.getTables().stream().map(t -> new MaterializeFromTable(t, Randomly.getBoolean())) + .collect(Collectors.toList())); + select.setFetchColumns(columns); + if (Randomly.getBoolean()) { + select.setWhereClause(gen.generateExpression(0, MaterializeDataType.BOOLEAN)); + } + if (Randomly.getBooleanWithRatherLowProbability()) { + select.setOrderByClauses(gen.generateOrderBys()); + } + if (Randomly.getBoolean()) { + select.setLimitClause(MaterializeConstant.createIntConstant(Randomly.getPositiveOrZeroNonCachedInteger())); + if (Randomly.getBoolean()) { + select.setOffsetClause( + MaterializeConstant.createIntConstant(Randomly.getPositiveOrZeroNonCachedInteger())); + } + } + if (Randomly.getBooleanWithRatherLowProbability()) { + select.setForClause(ForClause.getRandom()); + } + return new MaterializeSubquery(select, name); + } +} diff --git a/src/sqlancer/materialize/oracle/tlp/MaterializeTLPHavingOracle.java b/src/sqlancer/materialize/oracle/tlp/MaterializeTLPHavingOracle.java new file mode 100644 index 000000000..deefb4fb9 --- /dev/null +++ b/src/sqlancer/materialize/oracle/tlp/MaterializeTLPHavingOracle.java @@ -0,0 +1,84 @@ +package sqlancer.materialize.oracle.tlp; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import sqlancer.ComparatorHelper; +import sqlancer.Randomly; +import sqlancer.materialize.MaterializeGlobalState; +import sqlancer.materialize.MaterializeSchema.MaterializeDataType; +import sqlancer.materialize.MaterializeVisitor; +import sqlancer.materialize.ast.MaterializeExpression; +import sqlancer.materialize.gen.MaterializeCommon; + +public class MaterializeTLPHavingOracle extends MaterializeTLPBase { + private String generatedQueryString; + + public MaterializeTLPHavingOracle(MaterializeGlobalState state) { + super(state); + MaterializeCommon.addGroupingErrors(errors); + } + + @Override + public void check() throws SQLException { + super.check(); + havingCheck(); + } + + protected void havingCheck() throws SQLException { + if (Randomly.getBoolean()) { + select.setWhereClause(gen.generateExpression(MaterializeDataType.BOOLEAN)); + } + select.setGroupByExpressions(gen.generateExpressions(Randomly.smallNumber() + 1)); + select.setHavingClause(null); + String originalQueryString = MaterializeVisitor.asString(select); + generatedQueryString = originalQueryString; + List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); + + // See https://github.com/MaterializeInc/materialize/issues/18346, have to check + // if predicate errors by putting + // it in SELECT first + List originalColumns = select.getFetchColumns(); + List checkColumns = new ArrayList<>(); + checkColumns.add(predicate); + select.setFetchColumns(checkColumns); + String errorCheckQueryString = MaterializeVisitor.asString(select); + ComparatorHelper.getResultSetFirstColumnAsString(errorCheckQueryString, errors, state); + select.setFetchColumns(originalColumns); + + boolean orderBy = Randomly.getBoolean(); + if (orderBy) { + select.setOrderByClauses(gen.generateOrderBys()); + } + select.setHavingClause(predicate); + String firstQueryString = MaterializeVisitor.asString(select); + select.setHavingClause(negatedPredicate); + String secondQueryString = MaterializeVisitor.asString(select); + select.setHavingClause(isNullPredicate); + String thirdQueryString = MaterializeVisitor.asString(select); + List combinedString = new ArrayList<>(); + List secondResultSet = ComparatorHelper.getCombinedResultSet(firstQueryString, secondQueryString, + thirdQueryString, combinedString, !orderBy, state, errors); + ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, + state); + } + + @Override + protected MaterializeExpression generatePredicate() { + return gen.generateHavingClause(); + } + + @Override + List generateFetchColumns() { + List expressions = gen.allowAggregates(true) + .generateExpressions(Randomly.smallNumber() + 1); + gen.allowAggregates(false); + return expressions; + } + + @Override + public String getLastQueryString() { + return generatedQueryString; + } +} diff --git a/src/sqlancer/mongodb/MongoDBComparatorHelper.java b/src/sqlancer/mongodb/MongoDBComparatorHelper.java deleted file mode 100644 index 49b692645..000000000 --- a/src/sqlancer/mongodb/MongoDBComparatorHelper.java +++ /dev/null @@ -1,97 +0,0 @@ -package sqlancer.mongodb; - -import java.util.HashSet; -import java.util.List; -import java.util.Set; - -import org.bson.Document; - -import sqlancer.IgnoreMeException; -import sqlancer.Main; -import sqlancer.common.query.ExpectedErrors; -import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState; -import sqlancer.mongodb.query.MongoDBSelectQuery; - -public final class MongoDBComparatorHelper { - - private MongoDBComparatorHelper() { - } - - public static List getResultSetAsDocumentList(MongoDBSelectQuery adapter, MongoDBGlobalState state) - throws Exception { - ExpectedErrors errors = adapter.getExpectedErrors(); - List result; - try { - adapter.executeAndGet(state); - Main.nrSuccessfulActions.addAndGet(1); - result = adapter.getResultSet(); - return result; - } catch (Exception e) { - if (e instanceof IgnoreMeException) { - throw e; - } - Main.nrUnsuccessfulActions.addAndGet(1); - if (e.getMessage() == null) { - throw new AssertionError(adapter.getLogString(), e); - } - if (errors.errorIsExpected(e.getMessage())) { - throw new IgnoreMeException(); - } - throw new AssertionError(adapter.getLogString(), e); - } - } - - public static void assumeCountIsEqual(List resultSet, List secondResultSet, - MongoDBSelectQuery originalQuery) { - int originalSize = resultSet.size(); - if (secondResultSet.isEmpty()) { - if (originalSize == 0) { - return; - } else { - String assertMessage = String.format("The Count of the result set mismatches!\n %s", - originalQuery.getLogString()); - throw new AssertionError(assertMessage); - } - } - if (secondResultSet.size() != 1) { - throw new AssertionError( - String.format("Count query result bigger than one \n %s", originalQuery.getLogString())); - } - int withCount = (int) secondResultSet.get(0).get("count"); - if (originalSize != withCount) { - String assertMessage = String.format("The Count of the result set mismatches!\n %s", - originalQuery.getLogString()); - throw new AssertionError(assertMessage); - } - } - - public static void assumeResultSetsAreEqual(List resultSet, List secondResultSet, - MongoDBSelectQuery originalQuery) { - if (resultSet.size() != secondResultSet.size()) { - String assertionMessage = String.format("The Size of the result sets mismatch (%d and %d)!\n%s", - resultSet.size(), secondResultSet.size(), originalQuery.getLogString()); - throw new AssertionError(assertionMessage); - } - - Set firstHashSet = new HashSet<>(resultSet); - Set secondHashSet = new HashSet<>(secondResultSet); - - if (!firstHashSet.equals(secondHashSet)) { - Set firstResultSetMisses = new HashSet<>(firstHashSet); - firstResultSetMisses.removeAll(secondHashSet); - Set secondResultSetMisses = new HashSet<>(secondHashSet); - secondResultSetMisses.removeAll(firstHashSet); - StringBuilder firstMisses = new StringBuilder(); - for (Document document : firstResultSetMisses) { - firstMisses.append(document.toJson()).append(" "); - } - StringBuilder secondMisses = new StringBuilder(); - for (Document document : secondResultSetMisses) { - secondMisses.append(document.toJson()).append(" "); - } - String assertMessage = String.format("The Content of the result sets mismatch!\n %s \n %s\n %s", - firstMisses.toString(), secondMisses.toString(), originalQuery.getLogString()); - throw new AssertionError(assertMessage); - } - } -} diff --git a/src/sqlancer/mongodb/MongoDBConnection.java b/src/sqlancer/mongodb/MongoDBConnection.java deleted file mode 100644 index 6971bd79c..000000000 --- a/src/sqlancer/mongodb/MongoDBConnection.java +++ /dev/null @@ -1,35 +0,0 @@ -package sqlancer.mongodb; - -import org.bson.BsonDocument; -import org.bson.BsonString; - -import com.mongodb.client.MongoClient; -import com.mongodb.client.MongoDatabase; - -import sqlancer.SQLancerDBConnection; - -public class MongoDBConnection implements SQLancerDBConnection { - - private final MongoClient client; - private final MongoDatabase database; - - public MongoDBConnection(MongoClient client, MongoDatabase database) { - this.client = client; - this.database = database; - } - - @Override - public String getDatabaseVersion() throws Exception { - return client.getDatabase("dbname").runCommand(new BsonDocument("buildinfo", new BsonString(""))).get("version") - .toString(); - } - - @Override - public void close() throws Exception { - client.close(); - } - - public MongoDatabase getDatabase() { - return database; - } -} diff --git a/src/sqlancer/mongodb/MongoDBLoggableFactory.java b/src/sqlancer/mongodb/MongoDBLoggableFactory.java deleted file mode 100644 index b668301b3..000000000 --- a/src/sqlancer/mongodb/MongoDBLoggableFactory.java +++ /dev/null @@ -1,40 +0,0 @@ -package sqlancer.mongodb; - -import java.util.Arrays; - -import sqlancer.common.log.Loggable; -import sqlancer.common.log.LoggableFactory; -import sqlancer.common.log.LoggedString; -import sqlancer.common.query.Query; - -public class MongoDBLoggableFactory extends LoggableFactory { - @Override - protected Loggable createLoggable(String input, String suffix) { - return new LoggedString(input + suffix); - } - - @Override - public Query getQueryForStateToReproduce(String queryString) { - throw new UnsupportedOperationException(); - } - - @Override - public Query commentOutQuery(Query query) { - throw new UnsupportedOperationException(); - } - - @Override - protected Loggable infoToLoggable(String time, String databaseName, String databaseVersion, long seedValue) { - StringBuilder sb = new StringBuilder(); - sb.append("// Time: ").append(time).append("\n"); - sb.append("// Database: ").append(databaseName).append("\n"); - sb.append("// Database version: ").append(databaseVersion).append("\n"); - sb.append("// seed value: ").append(seedValue).append("\n"); - return new LoggedString(sb.toString()); - } - - @Override - public Loggable convertStacktraceToLoggable(Throwable throwable) { - return new LoggedString(Arrays.toString(throwable.getStackTrace()) + "\n" + throwable.getMessage()); - } -} diff --git a/src/sqlancer/mongodb/MongoDBOptions.java b/src/sqlancer/mongodb/MongoDBOptions.java deleted file mode 100644 index 9e02de67b..000000000 --- a/src/sqlancer/mongodb/MongoDBOptions.java +++ /dev/null @@ -1,74 +0,0 @@ -package sqlancer.mongodb; - -import static sqlancer.mongodb.MongoDBOptions.MongoDBOracleFactory.DOCUMENT_REMOVAL; -import static sqlancer.mongodb.MongoDBOptions.MongoDBOracleFactory.QUERY_PARTITIONING; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import com.beust.jcommander.Parameter; -import com.beust.jcommander.Parameters; - -import sqlancer.DBMSSpecificOptions; -import sqlancer.OracleFactory; -import sqlancer.common.oracle.CompositeTestOracle; -import sqlancer.common.oracle.TestOracle; -import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState; -import sqlancer.mongodb.test.MongoDBDocumentRemovalTester; -import sqlancer.mongodb.test.MongoDBQueryPartitioningWhereTester; - -@Parameters(commandDescription = "MongoDB (experimental)") -public class MongoDBOptions implements DBMSSpecificOptions { - - @Parameter(names = "--test-validation", description = "Enable/Disable validation of schema with Schema Validation", arity = 1) - public boolean testValidation = true; - - @Parameter(names = "--test-null-inserts", description = "Enables to test inserting with null values, validation has to be off", arity = 1) - public boolean testNullInserts; - - @Parameter(names = "--test-random-types", description = "Insert random types instead of schema types, validation has to be off", arity = 1) - public boolean testRandomTypes; - - @Parameter(names = "--max-number-indexes", description = "The maximum number of indexes used.", arity = 1) - public int maxNumberIndexes = 15; - - @Parameter(names = "--test-computed-values", description = "Enable adding computed values to query", arity = 1) - public boolean testComputedValues; - - @Parameter(names = "--test-with-regex", description = "Enable Regex Leaf Nodes", arity = 1) - public boolean testWithRegex; - - @Parameter(names = "--test-with-count", description = "Count the number of documents and check with count command", arity = 1) - public boolean testWithCount; - - @Parameter(names = "--null-safety", description = "", arity = 1) - public boolean nullSafety; - - @Parameter(names = "--oracle") - public List oracles = Arrays.asList(QUERY_PARTITIONING, DOCUMENT_REMOVAL); - - @Override - public List getTestOracleFactory() { - return oracles; - } - - public enum MongoDBOracleFactory implements OracleFactory { - QUERY_PARTITIONING { - @Override - public TestOracle create(MongoDBGlobalState globalState) throws Exception { - List oracles = new ArrayList<>(); - oracles.add(new MongoDBQueryPartitioningWhereTester(globalState)); - return new CompositeTestOracle(oracles, globalState); - } - }, - DOCUMENT_REMOVAL { - @Override - public TestOracle create(MongoDBGlobalState globalState) throws Exception { - List oracles = new ArrayList<>(); - oracles.add(new MongoDBDocumentRemovalTester(globalState)); - return new CompositeTestOracle(oracles, globalState); - } - } - } -} diff --git a/src/sqlancer/mongodb/MongoDBProvider.java b/src/sqlancer/mongodb/MongoDBProvider.java deleted file mode 100644 index 09662a863..000000000 --- a/src/sqlancer/mongodb/MongoDBProvider.java +++ /dev/null @@ -1,128 +0,0 @@ -package sqlancer.mongodb; - -import java.util.ArrayList; -import java.util.List; - -import com.google.auto.service.AutoService; -import com.mongodb.client.MongoClient; -import com.mongodb.client.MongoClients; -import com.mongodb.client.MongoDatabase; - -import sqlancer.AbstractAction; -import sqlancer.DatabaseProvider; -import sqlancer.ExecutionTimer; -import sqlancer.GlobalState; -import sqlancer.IgnoreMeException; -import sqlancer.ProviderAdapter; -import sqlancer.Randomly; -import sqlancer.StatementExecutor; -import sqlancer.common.log.LoggableFactory; -import sqlancer.common.query.Query; -import sqlancer.mongodb.MongoDBSchema.MongoDBTable; -import sqlancer.mongodb.gen.MongoDBIndexGenerator; -import sqlancer.mongodb.gen.MongoDBInsertGenerator; -import sqlancer.mongodb.gen.MongoDBTableGenerator; - -@AutoService(DatabaseProvider.class) -public class MongoDBProvider - extends ProviderAdapter { - - public MongoDBProvider() { - super(MongoDBGlobalState.class, MongoDBOptions.class); - } - - public enum Action implements AbstractAction { - INSERT(MongoDBInsertGenerator::getQuery), CREATE_INDEX(MongoDBIndexGenerator::getQuery); - - private final MongoDBQueryProvider queryProvider; - - Action(MongoDBQueryProvider queryProvider) { - this.queryProvider = queryProvider; - } - - @Override - public Query getQuery(MongoDBGlobalState globalState) throws Exception { - return queryProvider.getQuery(globalState); - } - } - - public static int mapActions(MongoDBGlobalState globalState, Action a) { - Randomly r = globalState.getRandomly(); - switch (a) { - case INSERT: - return r.getInteger(0, globalState.getOptions().getMaxNumberInserts()); - case CREATE_INDEX: - return r.getInteger(0, globalState.getDbmsSpecificOptions().maxNumberIndexes); - default: - throw new AssertionError(a); - } - } - - public static class MongoDBGlobalState extends GlobalState { - - private final List schemaTables = new ArrayList<>(); - - public void addTable(MongoDBTable table) { - schemaTables.add(table); - } - - @Override - protected void executeEpilogue(Query q, boolean success, ExecutionTimer timer) throws Exception { - boolean logExecutionTime = getOptions().logExecutionTime(); - if (success && getOptions().printSucceedingStatements()) { - System.out.println(q.getLogString()); - } - if (logExecutionTime) { - getLogger().writeCurrent("// " + timer.end().asString()); - } - if (q.couldAffectSchema()) { - updateSchema(); - } - } - - @Override - protected MongoDBSchema readSchema() throws Exception { - return new MongoDBSchema(schemaTables); - } - } - - @Override - public void generateDatabase(MongoDBGlobalState globalState) throws Exception { - for (int i = 0; i < Randomly.fromOptions(4, 5, 6); i++) { - boolean success; - do { - MongoDBQueryAdapter query = new MongoDBTableGenerator(globalState).getQuery(globalState); - success = globalState.executeStatement(query); - } while (!success); - } - StatementExecutor se = new StatementExecutor<>(globalState, Action.values(), - MongoDBProvider::mapActions, (q) -> { - if (globalState.getSchema().getDatabaseTables().isEmpty()) { - throw new IgnoreMeException(); - } - }); - se.executeStatements(); - } - - @Override - public MongoDBConnection createDatabase(MongoDBGlobalState globalState) throws Exception { - MongoClient mongoClient = MongoClients.create(); - MongoDatabase database = mongoClient.getDatabase(globalState.getDatabaseName()); - database.drop(); - return new MongoDBConnection(mongoClient, database); - } - - @Override - public String getDBMSName() { - return "mongodb"; - } - - @Override - public LoggableFactory getLoggableFactory() { - return new MongoDBLoggableFactory(); - } - - @Override - protected void checkViewsAreValid(MongoDBGlobalState globalState) { - } -} diff --git a/src/sqlancer/mongodb/MongoDBQueryAdapter.java b/src/sqlancer/mongodb/MongoDBQueryAdapter.java deleted file mode 100644 index e2add3242..000000000 --- a/src/sqlancer/mongodb/MongoDBQueryAdapter.java +++ /dev/null @@ -1,15 +0,0 @@ -package sqlancer.mongodb; - -import sqlancer.common.query.Query; - -public abstract class MongoDBQueryAdapter extends Query { - @Override - public String getQueryString() { - throw new UnsupportedOperationException(); - } - - @Override - public String getUnterminatedQueryString() { - throw new UnsupportedOperationException(); - } -} diff --git a/src/sqlancer/mongodb/MongoDBQueryProvider.java b/src/sqlancer/mongodb/MongoDBQueryProvider.java deleted file mode 100644 index 970c90cea..000000000 --- a/src/sqlancer/mongodb/MongoDBQueryProvider.java +++ /dev/null @@ -1,6 +0,0 @@ -package sqlancer.mongodb; - -@FunctionalInterface -public interface MongoDBQueryProvider { - MongoDBQueryAdapter getQuery(S globalState) throws Exception; -} diff --git a/src/sqlancer/mongodb/MongoDBSchema.java b/src/sqlancer/mongodb/MongoDBSchema.java deleted file mode 100644 index 5ae3cdd24..000000000 --- a/src/sqlancer/mongodb/MongoDBSchema.java +++ /dev/null @@ -1,97 +0,0 @@ -package sqlancer.mongodb; - -import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Set; - -import org.bson.BsonType; - -import com.mongodb.client.MongoDatabase; - -import sqlancer.Randomly; -import sqlancer.common.schema.AbstractSchema; -import sqlancer.common.schema.AbstractTable; -import sqlancer.common.schema.AbstractTableColumn; -import sqlancer.common.schema.AbstractTables; -import sqlancer.common.schema.TableIndex; -import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState; - -public class MongoDBSchema extends AbstractSchema { - - public enum MongoDBDataType { - INTEGER(BsonType.INT32), STRING(BsonType.STRING), BOOLEAN(BsonType.BOOLEAN), DOUBLE(BsonType.DOUBLE), - DATE_TIME(BsonType.DATE_TIME), TIMESTAMP(BsonType.TIMESTAMP); - - private final BsonType bsonType; - - MongoDBDataType(BsonType type) { - this.bsonType = type; - } - - public BsonType getBsonType() { - return bsonType; - } - - public static MongoDBDataType getRandom(MongoDBGlobalState state) { - Set valueSet = new HashSet<>(Arrays.asList(values())); - if (state.getDbmsSpecificOptions().nullSafety) { - valueSet.remove(STRING); - } - MongoDBDataType[] configuredValues = new MongoDBDataType[valueSet.size()]; - return Randomly.fromOptions(valueSet.toArray(configuredValues)); - } - } - - public static class MongoDBColumn extends AbstractTableColumn { - - private final boolean isId; - private final boolean isNullable; - - public MongoDBColumn(String name, MongoDBDataType type, boolean isId, boolean isNullable) { - super(name, null, type); - this.isId = isId; - this.isNullable = isNullable; - } - - public boolean isId() { - return isId; - } - - public boolean isNullable() { - return isNullable; - } - - } - - public static class MongoDBTables extends AbstractTables { - - public MongoDBTables(List tables) { - super(tables); - } - } - - public MongoDBSchema(List databaseTables) { - super(databaseTables); - } - - public static class MongoDBTable extends AbstractTable { - public MongoDBTable(String name, List columns, boolean isView) { - super(name, columns, Collections.emptyList(), isView); - } - - @Override - public long getNrRows(MongoDBGlobalState globalState) { - throw new UnsupportedOperationException(); - } - } - - public static MongoDBSchema fromConnection(MongoDatabase connection, String databaseName) { - throw new UnsupportedOperationException(); - } - - public MongoDBTables getRandomTableNonEmptyTables() { - return new MongoDBTables(Randomly.nonEmptySubset(getDatabaseTables())); - } -} diff --git a/src/sqlancer/mongodb/ast/MongoDBBinaryComparisonNode.java b/src/sqlancer/mongodb/ast/MongoDBBinaryComparisonNode.java deleted file mode 100644 index 21675250a..000000000 --- a/src/sqlancer/mongodb/ast/MongoDBBinaryComparisonNode.java +++ /dev/null @@ -1,16 +0,0 @@ -package sqlancer.mongodb.ast; - -import sqlancer.common.ast.newast.NewBinaryOperatorNode; -import sqlancer.common.ast.newast.Node; -import sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator.MongoDBBinaryComparisonOperator; - -public class MongoDBBinaryComparisonNode extends NewBinaryOperatorNode { - public MongoDBBinaryComparisonNode(Node left, Node right, - MongoDBBinaryComparisonOperator op) { - super(left, right, op); - } - - public MongoDBBinaryComparisonOperator operator() { - return (MongoDBBinaryComparisonOperator) op; - } -} diff --git a/src/sqlancer/mongodb/ast/MongoDBBinaryLogicalNode.java b/src/sqlancer/mongodb/ast/MongoDBBinaryLogicalNode.java deleted file mode 100644 index efb8d8294..000000000 --- a/src/sqlancer/mongodb/ast/MongoDBBinaryLogicalNode.java +++ /dev/null @@ -1,16 +0,0 @@ -package sqlancer.mongodb.ast; - -import sqlancer.common.ast.newast.NewBinaryOperatorNode; -import sqlancer.common.ast.newast.Node; -import sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator.MongoDBBinaryLogicalOperator; - -public class MongoDBBinaryLogicalNode extends NewBinaryOperatorNode { - public MongoDBBinaryLogicalNode(Node left, Node right, - MongoDBBinaryLogicalOperator op) { - super(left, right, op); - } - - public MongoDBBinaryLogicalOperator operator() { - return (MongoDBBinaryLogicalOperator) op; - } -} diff --git a/src/sqlancer/mongodb/ast/MongoDBConstant.java b/src/sqlancer/mongodb/ast/MongoDBConstant.java deleted file mode 100644 index 86f783b48..000000000 --- a/src/sqlancer/mongodb/ast/MongoDBConstant.java +++ /dev/null @@ -1,252 +0,0 @@ -package sqlancer.mongodb.ast; - -import java.io.Serializable; - -import org.bson.BsonDateTime; -import org.bson.BsonTimestamp; -import org.bson.Document; - -import sqlancer.common.ast.newast.Node; - -public abstract class MongoDBConstant implements Node { - private MongoDBConstant() { - } - - public abstract void setValueInDocument(Document document, String key); - - public abstract String getLogValue(); - - public abstract Object getValue(); - - public abstract Serializable getSerializedValue(); - - public static class MongoDBNullConstant extends MongoDBConstant { - - @Override - public void setValueInDocument(Document document, String key) { - document.append(key, null); - } - - @Override - public String getLogValue() { - return "null"; - } - - @Override - public Object getValue() { - return null; - } - - @Override - public Serializable getSerializedValue() { - return null; - } - } - - public static Node createNullConstant() { - return new MongoDBNullConstant(); - } - - public static class MongoDBIntegerConstant extends MongoDBConstant { - - private final int value; - - public MongoDBIntegerConstant(int value) { - this.value = value; - } - - @Override - public void setValueInDocument(Document document, String key) { - document.append(key, value); - } - - @Override - public String getLogValue() { - return "NumberInt(" + value + ")"; - } - - @Override - public Integer getValue() { - return value; - } - - @Override - public Serializable getSerializedValue() { - return value; - } - } - - public static Node createIntegerConstant(int value) { - return new MongoDBIntegerConstant(value); - } - - public static class MongoDBStringConstant extends MongoDBConstant { - - private final String value; - - public MongoDBStringConstant(String value) { - this.value = value; - } - - public String getStringValue() { - return value; - } - - @Override - public void setValueInDocument(Document document, String key) { - document.append(key, value); - } - - @Override - public String getLogValue() { - return "\"" + value.replace("\\", "\\\\").replace("\"", "\\\"").replace("\n", "\\n") + "\""; - } - - @Override - public String getValue() { - return value; - } - - @Override - public Serializable getSerializedValue() { - return value; - } - } - - public static Node createStringConstant(String value) { - return new MongoDBStringConstant(value); - } - - public static class MongoDBBooleanConstant extends MongoDBConstant { - - private final boolean value; - - public MongoDBBooleanConstant(boolean value) { - this.value = value; - } - - @Override - public void setValueInDocument(Document document, String key) { - document.append(key, value); - } - - @Override - public String getLogValue() { - return String.valueOf(value); - } - - @Override - public Boolean getValue() { - return value; - } - - @Override - public Serializable getSerializedValue() { - return value; - } - } - - public static Node createBooleanConstant(boolean value) { - return new MongoDBBooleanConstant(value); - } - - public static class MongoDBDoubleConstant extends MongoDBConstant { - - private final double value; - - public MongoDBDoubleConstant(double value) { - this.value = value; - } - - @Override - public void setValueInDocument(Document document, String key) { - document.append(key, value); - } - - @Override - public String getLogValue() { - return String.valueOf(value); - } - - @Override - public Double getValue() { - return value; - } - - @Override - public Serializable getSerializedValue() { - return value; - } - } - - public static Node createDoubleConstant(double value) { - return new MongoDBDoubleConstant(value); - } - - public static class MongoDBDateTimeConstant extends MongoDBConstant { - - private final BsonDateTime value; - - public MongoDBDateTimeConstant(long val) { - this.value = new BsonDateTime(val); - } - - @Override - public void setValueInDocument(Document document, String key) { - document.append(key, value); - } - - @Override - public String getLogValue() { - return "new Date(" + value.getValue() + ")"; - } - - @Override - public BsonDateTime getValue() { - return value; - } - - @Override - public Serializable getSerializedValue() { - return value.getValue(); - } - } - - public static Node createDateTimeConstant(long value) { - return new MongoDBDateTimeConstant(value); - } - - public static class MongoDBTimestampConstant extends MongoDBConstant { - - private final BsonTimestamp value; - - public MongoDBTimestampConstant(long value) { - this.value = new BsonTimestamp(value); - } - - @Override - public void setValueInDocument(Document document, String key) { - document.append(key, value); - } - - @Override - public String getLogValue() { - return "Timestamp(" + value.getValue() + ",1)"; - } - - @Override - public BsonTimestamp getValue() { - return value; - } - - @Override - public Serializable getSerializedValue() { - return value.getValue(); - } - } - - public static Node createTimestampConstant(long value) { - return new MongoDBTimestampConstant(value); - } - -} diff --git a/src/sqlancer/mongodb/ast/MongoDBExpression.java b/src/sqlancer/mongodb/ast/MongoDBExpression.java deleted file mode 100644 index 1235a1fbc..000000000 --- a/src/sqlancer/mongodb/ast/MongoDBExpression.java +++ /dev/null @@ -1,4 +0,0 @@ -package sqlancer.mongodb.ast; - -public interface MongoDBExpression { -} diff --git a/src/sqlancer/mongodb/ast/MongoDBRegexNode.java b/src/sqlancer/mongodb/ast/MongoDBRegexNode.java deleted file mode 100644 index 76c608586..000000000 --- a/src/sqlancer/mongodb/ast/MongoDBRegexNode.java +++ /dev/null @@ -1,24 +0,0 @@ -package sqlancer.mongodb.ast; - -import static sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator.MongoDBRegexOperator.REGEX; - -import sqlancer.common.ast.newast.NewBinaryOperatorNode; -import sqlancer.common.ast.newast.Node; -import sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator.MongoDBRegexOperator; - -public class MongoDBRegexNode extends NewBinaryOperatorNode { - private final String options; - - public MongoDBRegexNode(Node left, Node right, String options) { - super(left, right, REGEX); - this.options = options; - } - - public String getOptions() { - return options; - } - - public MongoDBRegexOperator operator() { - return (MongoDBRegexOperator) op; - } -} diff --git a/src/sqlancer/mongodb/ast/MongoDBSelect.java b/src/sqlancer/mongodb/ast/MongoDBSelect.java deleted file mode 100644 index 0fe91ba4a..000000000 --- a/src/sqlancer/mongodb/ast/MongoDBSelect.java +++ /dev/null @@ -1,104 +0,0 @@ -package sqlancer.mongodb.ast; - -import java.util.List; - -import sqlancer.common.ast.newast.Node; -import sqlancer.mongodb.test.MongoDBColumnTestReference; - -public class MongoDBSelect implements Node { - - private final String mainTableName; - private final MongoDBColumnTestReference joinColumn; - List projectionColumns; - List lookupList; - boolean hasFilter; - Node filterClause; - boolean hasComputed; - List> computedClauses; - private boolean withCountClause; - - public MongoDBSelect(String mainTableName, MongoDBColumnTestReference joinColumn) { - this.mainTableName = mainTableName; - this.joinColumn = joinColumn; - } - - public String getMainTableName() { - return mainTableName; - } - - public MongoDBColumnTestReference getJoinColumn() { - return joinColumn; - } - - public void setProjectionList(List fetchColumns) { - if (fetchColumns == null || fetchColumns.isEmpty()) { - throw new IllegalArgumentException(); - } - this.projectionColumns = fetchColumns; - } - - public List getProjectionList() { - if (projectionColumns == null) { - throw new IllegalStateException(); - } - return projectionColumns; - } - - public void setLookupList(List lookupList) { - if (lookupList == null || lookupList.isEmpty()) { - throw new IllegalArgumentException(); - } - this.lookupList = lookupList; - } - - public List getLookupList() { - if (lookupList == null) { - throw new IllegalStateException(); - } - return lookupList; - } - - public void setFilterClause(Node filterClause) { - if (filterClause == null) { - hasFilter = false; - this.filterClause = null; - return; - } - hasFilter = true; - this.filterClause = filterClause; - } - - public Node getFilterClause() { - return filterClause; - } - - public boolean hasFilter() { - return hasFilter; - } - - public void setComputedClause(List> computedClause) { - if (computedClause == null) { - hasComputed = false; - this.computedClauses = null; - return; - } - hasComputed = true; - this.computedClauses = computedClause; - } - - public List> getComputedClause() { - return computedClauses; - } - - public boolean hasComputed() { - return hasComputed; - } - - public boolean getWithCountClause() { - return withCountClause; - } - - public void setWithCountClause(boolean withCountClause) { - this.withCountClause = withCountClause; - } -} diff --git a/src/sqlancer/mongodb/ast/MongoDBUnaryLogicalOperatorNode.java b/src/sqlancer/mongodb/ast/MongoDBUnaryLogicalOperatorNode.java deleted file mode 100644 index a34fe27e5..000000000 --- a/src/sqlancer/mongodb/ast/MongoDBUnaryLogicalOperatorNode.java +++ /dev/null @@ -1,16 +0,0 @@ -package sqlancer.mongodb.ast; - -import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode; -import sqlancer.common.ast.newast.Node; -import sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator.MongoDBUnaryLogicalOperator; - -public class MongoDBUnaryLogicalOperatorNode extends NewUnaryPrefixOperatorNode { - - public MongoDBUnaryLogicalOperatorNode(Node expr, MongoDBUnaryLogicalOperator op) { - super(expr, op); - } - - public MongoDBUnaryLogicalOperator operator() { - return (MongoDBUnaryLogicalOperator) op; - } -} diff --git a/src/sqlancer/mongodb/ast/MongoDBUnsupportedPredicate.java b/src/sqlancer/mongodb/ast/MongoDBUnsupportedPredicate.java deleted file mode 100644 index eae143e7d..000000000 --- a/src/sqlancer/mongodb/ast/MongoDBUnsupportedPredicate.java +++ /dev/null @@ -1,7 +0,0 @@ -package sqlancer.mongodb.ast; - -import sqlancer.common.ast.newast.Node; - -public class MongoDBUnsupportedPredicate implements Node { - -} diff --git a/src/sqlancer/mongodb/gen/MongoDBComputedExpressionGenerator.java b/src/sqlancer/mongodb/gen/MongoDBComputedExpressionGenerator.java deleted file mode 100644 index 347d5bc97..000000000 --- a/src/sqlancer/mongodb/gen/MongoDBComputedExpressionGenerator.java +++ /dev/null @@ -1,89 +0,0 @@ -package sqlancer.mongodb.gen; - -import java.util.ArrayList; -import java.util.List; - -import sqlancer.Randomly; -import sqlancer.common.ast.newast.NewFunctionNode; -import sqlancer.common.ast.newast.Node; -import sqlancer.common.gen.UntypedExpressionGenerator; -import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState; -import sqlancer.mongodb.MongoDBSchema; -import sqlancer.mongodb.ast.MongoDBExpression; -import sqlancer.mongodb.test.MongoDBColumnTestReference; - -public class MongoDBComputedExpressionGenerator - extends UntypedExpressionGenerator, MongoDBColumnTestReference> { - - private final MongoDBGlobalState globalState; - - @Override - public Node generateLeafNode() { - ComputedFunction function = ComputedFunction.getRandom(); - List> expressions = new ArrayList<>(); - for (int i = 0; i < function.getNrArgs(); i++) { - expressions.add(super.generateLeafNode()); - } - return new NewFunctionNode<>(expressions, function); - } - - @Override - protected Node generateExpression(int depth) { - if (depth >= globalState.getOptions().getMaxExpressionDepth() || Randomly.getBoolean()) { - return generateLeafNode(); - } - ComputedFunction func = ComputedFunction.getRandom(); - return new NewFunctionNode<>(generateExpressions(func.getNrArgs(), depth + 1), func); - } - - public MongoDBComputedExpressionGenerator(MongoDBGlobalState globalState) { - this.globalState = globalState; - } - - public enum ComputedFunction { - ADD(2, "$add"), MULTIPLY(2, "$multiply"), DIVIDE(2, "$divide"), POW(2, "$pow"), SQRT(1, "$sqrt"), - LOG(2, "$log"), AVG(2, "$avg"), EXP(1, "$exp"); - - private final int nrArgs; - private final String operatorName; - - ComputedFunction(int nrArgs, String operatorName) { - this.nrArgs = nrArgs; - this.operatorName = operatorName; - } - - public static ComputedFunction getRandom() { - return Randomly.fromOptions(values()); - } - - public int getNrArgs() { - return nrArgs; - } - - public String getOperator() { - return operatorName; - } - } - - @Override - public Node generateConstant() { - MongoDBSchema.MongoDBDataType type = MongoDBSchema.MongoDBDataType.getRandom(globalState); - MongoDBConstantGenerator generator = new MongoDBConstantGenerator(globalState); - return generator.generateConstantWithType(type); - } - - @Override - protected Node generateColumn() { - return Randomly.fromList(columns); - } - - @Override - public Node negatePredicate(Node predicate) { - throw new UnsupportedOperationException(); - } - - @Override - public Node isNull(Node expr) { - throw new UnsupportedOperationException(); - } -} diff --git a/src/sqlancer/mongodb/gen/MongoDBConstantGenerator.java b/src/sqlancer/mongodb/gen/MongoDBConstantGenerator.java deleted file mode 100644 index 2e6b15048..000000000 --- a/src/sqlancer/mongodb/gen/MongoDBConstantGenerator.java +++ /dev/null @@ -1,86 +0,0 @@ -package sqlancer.mongodb.gen; - -import org.bson.Document; - -import sqlancer.Randomly; -import sqlancer.common.ast.newast.Node; -import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState; -import sqlancer.mongodb.MongoDBSchema.MongoDBDataType; -import sqlancer.mongodb.ast.MongoDBConstant; -import sqlancer.mongodb.ast.MongoDBConstant.MongoDBBooleanConstant; -import sqlancer.mongodb.ast.MongoDBConstant.MongoDBDateTimeConstant; -import sqlancer.mongodb.ast.MongoDBConstant.MongoDBDoubleConstant; -import sqlancer.mongodb.ast.MongoDBConstant.MongoDBIntegerConstant; -import sqlancer.mongodb.ast.MongoDBConstant.MongoDBNullConstant; -import sqlancer.mongodb.ast.MongoDBConstant.MongoDBTimestampConstant; -import sqlancer.mongodb.ast.MongoDBExpression; - -public class MongoDBConstantGenerator { - private final MongoDBGlobalState globalState; - - public MongoDBConstantGenerator(MongoDBGlobalState globalState) { - this.globalState = globalState; - } - - public Node generateConstantWithType(MongoDBDataType option) { - switch (option) { - case DATE_TIME: - return MongoDBConstant.createDateTimeConstant(globalState.getRandomly().getInteger()); - case BOOLEAN: - return MongoDBConstant.createBooleanConstant(Randomly.getBoolean()); - case DOUBLE: - return MongoDBConstant.createDoubleConstant(globalState.getRandomly().getDouble()); - case STRING: - return MongoDBConstant.createStringConstant(globalState.getRandomly().getString()); - case INTEGER: - return MongoDBConstant.createIntegerConstant((int) globalState.getRandomly().getInteger()); - case TIMESTAMP: - return MongoDBConstant.createTimestampConstant(globalState.getRandomly().getInteger()); - default: - throw new AssertionError(option); - } - } - - public void addRandomConstant(Document document, String key) { - MongoDBDataType type = MongoDBDataType.getRandom(globalState); - addRandomConstantWithType(document, key, type); - } - - public void addRandomConstantWithType(Document document, String key, MongoDBDataType option) { - MongoDBConstant constant; - if (globalState.getDbmsSpecificOptions().testNullInserts && Randomly.getBooleanWithSmallProbability()) { - constant = new MongoDBNullConstant(); - constant.setValueInDocument(document, key); - return; - } - switch (option) { - case DATE_TIME: - constant = new MongoDBDateTimeConstant(globalState.getRandomly().getInteger()); - constant.setValueInDocument(document, key); - return; - - case BOOLEAN: - constant = new MongoDBBooleanConstant(Randomly.getBoolean()); - constant.setValueInDocument(document, key); - return; - case DOUBLE: - constant = new MongoDBDoubleConstant(globalState.getRandomly().getDouble()); - constant.setValueInDocument(document, key); - return; - case STRING: - constant = new MongoDBConstant.MongoDBStringConstant(globalState.getRandomly().getString()); - constant.setValueInDocument(document, key); - return; - case INTEGER: - constant = new MongoDBIntegerConstant((int) globalState.getRandomly().getInteger()); - constant.setValueInDocument(document, key); - return; - case TIMESTAMP: - constant = new MongoDBTimestampConstant(globalState.getRandomly().getInteger()); - constant.setValueInDocument(document, key); - return; - default: - throw new AssertionError(option); - } - } -} diff --git a/src/sqlancer/mongodb/gen/MongoDBIndexGenerator.java b/src/sqlancer/mongodb/gen/MongoDBIndexGenerator.java deleted file mode 100644 index 8687fd45c..000000000 --- a/src/sqlancer/mongodb/gen/MongoDBIndexGenerator.java +++ /dev/null @@ -1,25 +0,0 @@ -package sqlancer.mongodb.gen; - -import java.util.List; - -import sqlancer.Randomly; -import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState; -import sqlancer.mongodb.MongoDBQueryAdapter; -import sqlancer.mongodb.MongoDBSchema.MongoDBColumn; -import sqlancer.mongodb.MongoDBSchema.MongoDBTable; -import sqlancer.mongodb.query.MongoDBCreateIndexQuery; - -public final class MongoDBIndexGenerator { - private MongoDBIndexGenerator() { - } - - public static MongoDBQueryAdapter getQuery(MongoDBGlobalState globalState) { - MongoDBTable randomTable = globalState.getSchema().getRandomTable(); - List columns = Randomly.nonEmptySubset(randomTable.getColumns()); - MongoDBCreateIndexQuery createIndexQuery = new MongoDBCreateIndexQuery(randomTable); - for (MongoDBColumn column : columns) { - createIndexQuery.addIndex(column.getName(), Randomly.getBoolean()); - } - return createIndexQuery; - } -} diff --git a/src/sqlancer/mongodb/gen/MongoDBInsertGenerator.java b/src/sqlancer/mongodb/gen/MongoDBInsertGenerator.java deleted file mode 100644 index f8b8b3ffd..000000000 --- a/src/sqlancer/mongodb/gen/MongoDBInsertGenerator.java +++ /dev/null @@ -1,38 +0,0 @@ -package sqlancer.mongodb.gen; - -import org.bson.Document; - -import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState; -import sqlancer.mongodb.MongoDBQueryAdapter; -import sqlancer.mongodb.MongoDBSchema.MongoDBTable; -import sqlancer.mongodb.query.MongoDBInsertQuery; - -public final class MongoDBInsertGenerator { - - private final MongoDBGlobalState globalState; - - private MongoDBInsertGenerator(MongoDBGlobalState globalState) { - this.globalState = globalState; - } - - public static MongoDBQueryAdapter getQuery(MongoDBGlobalState globalState) { - return new MongoDBInsertGenerator(globalState).generate(); - } - - public MongoDBQueryAdapter generate() { - Document result = new Document(); - MongoDBTable table = globalState.getSchema().getRandomTable(); - MongoDBConstantGenerator constantGenerator = new MongoDBConstantGenerator(globalState); - - for (int i = 0; i < table.getColumns().size(); i++) { - if (!globalState.getDbmsSpecificOptions().testRandomTypes) { - constantGenerator.addRandomConstantWithType(result, table.getColumns().get(i).getName(), - table.getColumns().get(i).getType()); - } else { - constantGenerator.addRandomConstant(result, table.getColumns().get(i).getName()); - } - } - - return new MongoDBInsertQuery(table, result); - } -} diff --git a/src/sqlancer/mongodb/gen/MongoDBMatchExpressionGenerator.java b/src/sqlancer/mongodb/gen/MongoDBMatchExpressionGenerator.java deleted file mode 100644 index 3e6c833d3..000000000 --- a/src/sqlancer/mongodb/gen/MongoDBMatchExpressionGenerator.java +++ /dev/null @@ -1,291 +0,0 @@ -package sqlancer.mongodb.gen; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import org.bson.conversions.Bson; - -import com.mongodb.client.model.Filters; - -import sqlancer.Randomly; -import sqlancer.common.ast.BinaryOperatorNode.Operator; -import sqlancer.common.ast.newast.Node; -import sqlancer.common.gen.UntypedExpressionGenerator; -import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState; -import sqlancer.mongodb.MongoDBSchema.MongoDBDataType; -import sqlancer.mongodb.ast.MongoDBBinaryComparisonNode; -import sqlancer.mongodb.ast.MongoDBBinaryLogicalNode; -import sqlancer.mongodb.ast.MongoDBConstant; -import sqlancer.mongodb.ast.MongoDBExpression; -import sqlancer.mongodb.ast.MongoDBRegexNode; -import sqlancer.mongodb.ast.MongoDBUnaryLogicalOperatorNode; -import sqlancer.mongodb.ast.MongoDBUnsupportedPredicate; -import sqlancer.mongodb.test.MongoDBColumnTestReference; -import sqlancer.mongodb.visitor.MongoDBNegateVisitor; - -public class MongoDBMatchExpressionGenerator - extends UntypedExpressionGenerator, MongoDBColumnTestReference> { - - private final MongoDBGlobalState globalState; - - private enum LeafExpression { - BINARY_COMPARISON, REGEX - } - - private enum NonLeafExpression { - BINARY_LOGICAL, UNARY_LOGICAL - } - - public MongoDBMatchExpressionGenerator(MongoDBGlobalState globalState) { - this.globalState = globalState; - } - - @Override - public Node generateLeafNode() { - List possibleOptions = new ArrayList<>(Arrays.asList(LeafExpression.values())); - if (!globalState.getDbmsSpecificOptions().testWithRegex) { - possibleOptions.remove(LeafExpression.REGEX); - } - LeafExpression expr = Randomly.fromList(possibleOptions); - switch (expr) { - case BINARY_COMPARISON: - MongoDBBinaryComparisonOperator operator = MongoDBBinaryComparisonOperator.getRandom(); - MongoDBColumnTestReference reference = (MongoDBColumnTestReference) generateColumn(); - - return new MongoDBBinaryComparisonNode(reference, - generateConstant(reference.getColumnReference().getType()), operator); - case REGEX: - return new MongoDBRegexNode(generateColumn(), - new MongoDBConstantGenerator(globalState).generateConstantWithType(MongoDBDataType.STRING), - getRandomizedRegexOptions()); - default: - throw new AssertionError(); - } - } - - @Override - protected Node generateExpression(int depth) { - if (depth >= globalState.getOptions().getMaxExpressionDepth() || Randomly.getBoolean()) { - return generateLeafNode(); - } - - List possibleOptions = new ArrayList<>(Arrays.asList(NonLeafExpression.values())); - NonLeafExpression expr = Randomly.fromList(possibleOptions); - switch (expr) { - case BINARY_LOGICAL: - MongoDBBinaryLogicalOperator binaryOperator = MongoDBBinaryLogicalOperator.getRandom(); - return new MongoDBBinaryLogicalNode(generateExpression(depth + 1), generateExpression(depth + 1), - binaryOperator); - case UNARY_LOGICAL: - MongoDBUnaryLogicalOperator unaryOperator = MongoDBUnaryLogicalOperator.getRandom(); - return new MongoDBUnaryLogicalOperatorNode(generateExpression(depth + 1), unaryOperator); - default: - throw new AssertionError(); - } - } - - @Override - public Node generateConstant() { - MongoDBDataType type = MongoDBDataType.getRandom(globalState); - MongoDBConstantGenerator generator = new MongoDBConstantGenerator(globalState); - if (Randomly.getBooleanWithSmallProbability()) { - return MongoDBConstant.createNullConstant(); - } - return generator.generateConstantWithType(type); - } - - public Node generateConstant(MongoDBDataType type) { - MongoDBConstantGenerator generator = new MongoDBConstantGenerator(globalState); - if (Randomly.getBooleanWithSmallProbability() && !globalState.getDbmsSpecificOptions().nullSafety) { - return MongoDBConstant.createNullConstant(); - } - return generator.generateConstantWithType(type); - } - - private String getRandomizedRegexOptions() { - List s = Randomly.subset("i", "m", "x", "s"); - return String.join("", s); - } - - @Override - protected Node generateColumn() { - return Randomly.fromList(columns); - } - - @Override - public Node generatePredicate() { - Node result = super.generatePredicate(); - return MongoDBNegateVisitor.cleanNegations(result); - } - - @Override - public Node negatePredicate(Node predicate) { - Node result = new MongoDBUnaryLogicalOperatorNode(predicate, - MongoDBUnaryLogicalOperator.NOT); - return MongoDBNegateVisitor.cleanNegations(result); - } - - @Override - public Node isNull(Node expr) { - return new MongoDBUnsupportedPredicate<>(); - } - - public enum MongoDBUnaryLogicalOperator implements Operator { - NOT { - @Override - public Bson applyOperator(Bson inner) { - return Filters.not(inner); - } - - @Override - public String getTextRepresentation() { - return "$not"; - } - }; - - public abstract Bson applyOperator(Bson inner); - - public static MongoDBUnaryLogicalOperator getRandom() { - return Randomly.fromOptions(values()); - } - } - - public enum MongoDBBinaryLogicalOperator implements Operator { - AND { - @Override - public Bson applyOperator(Bson left, Bson right) { - return Filters.and(left, right); - } - - @Override - public String getTextRepresentation() { - return "$and"; - } - }, - OR { - @Override - public Bson applyOperator(Bson left, Bson right) { - return Filters.or(left, right); - } - - @Override - public String getTextRepresentation() { - return "$or"; - } - }, - NOR { - @Override - public Bson applyOperator(Bson left, Bson right) { - return Filters.nor(left, right); - } - - @Override - public String getTextRepresentation() { - return "$nor"; - } - }; - - public abstract Bson applyOperator(Bson left, Bson right); - - public static MongoDBBinaryLogicalOperator getRandom() { - return Randomly.fromOptions(values()); - } - } - - public enum MongoDBBinaryComparisonOperator implements Operator { - EQUALS { - @Override - public Bson applyOperator(String columnName, MongoDBConstant constant) { - return Filters.eq(columnName, constant.getValue()); - } - - @Override - public String getTextRepresentation() { - return "$eq"; - } - }, - NOT_EQUALS { - @Override - public Bson applyOperator(String columnName, MongoDBConstant constant) { - return Filters.ne(columnName, constant.getValue()); - } - - @Override - public String getTextRepresentation() { - return "$ne"; - } - }, - GREATER { - @Override - public Bson applyOperator(String columnName, MongoDBConstant constant) { - return Filters.gt(columnName, constant.getValue()); - } - - @Override - public String getTextRepresentation() { - return "$gt"; - } - - }, - LESS { - @Override - public Bson applyOperator(String columnName, MongoDBConstant constant) { - return Filters.lt(columnName, constant.getValue()); - } - - @Override - public String getTextRepresentation() { - return "$lt"; - } - - }, - GREATER_EQUAL { - @Override - public Bson applyOperator(String columnName, MongoDBConstant constant) { - return Filters.gte(columnName, constant.getValue()); - - } - - @Override - public String getTextRepresentation() { - return "$gte"; - } - - }, - LESS_EQUAL { - @Override - public Bson applyOperator(String columnName, MongoDBConstant constant) { - return Filters.lte(columnName, constant.getValue()); - } - - @Override - public String getTextRepresentation() { - return "$lte"; - } - }; - - public abstract Bson applyOperator(String columnName, MongoDBConstant constant); - - public static MongoDBBinaryComparisonOperator getRandom() { - return Randomly.fromOptions(values()); - } - } - - public enum MongoDBRegexOperator implements Operator { - REGEX { - @Override - public Bson applyOperator(String columnName, MongoDBConstant.MongoDBStringConstant regex, String options) { - return Filters.regex(columnName, regex.getStringValue(), options); - } - - @Override - public String getTextRepresentation() { - return "$regex"; - } - }; - - public abstract Bson applyOperator(String columnName, MongoDBConstant.MongoDBStringConstant regex, - String options); - } -} diff --git a/src/sqlancer/mongodb/gen/MongoDBTableGenerator.java b/src/sqlancer/mongodb/gen/MongoDBTableGenerator.java deleted file mode 100644 index 0153a6334..000000000 --- a/src/sqlancer/mongodb/gen/MongoDBTableGenerator.java +++ /dev/null @@ -1,54 +0,0 @@ -package sqlancer.mongodb.gen; - -import java.util.ArrayList; -import java.util.List; - -import sqlancer.Randomly; -import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState; -import sqlancer.mongodb.MongoDBQueryAdapter; -import sqlancer.mongodb.MongoDBSchema.MongoDBColumn; -import sqlancer.mongodb.MongoDBSchema.MongoDBDataType; -import sqlancer.mongodb.MongoDBSchema.MongoDBTable; -import sqlancer.mongodb.query.MongoDBCreateTableQuery; - -public class MongoDBTableGenerator { - - private MongoDBTable table; - private final List columnsToBeAdded = new ArrayList<>(); - private final MongoDBGlobalState state; - - public MongoDBTableGenerator(MongoDBGlobalState state) { - this.state = state; - } - - public MongoDBQueryAdapter getQuery(MongoDBGlobalState globalState) { - String tableName = globalState.getSchema().getFreeTableName(); - MongoDBCreateTableQuery createTableQuery = new MongoDBCreateTableQuery(tableName); - table = new MongoDBTable(tableName, columnsToBeAdded, false); - for (int i = 0; i < Randomly.smallNumber() + 1; i++) { - String columnName = String.format("c%d", i); - MongoDBDataType type = createColumn(columnName); - if (globalState.getDbmsSpecificOptions().testValidation) { - createTableQuery.addValidation(columnName, type.getBsonType()); - } - } - globalState.addTable(table); - return createTableQuery; - } - - private MongoDBDataType createColumn(String columnName) { - MongoDBDataType columnType = MongoDBDataType.getRandom(state); - MongoDBColumn newColumn = new MongoDBColumn(columnName, columnType, false, false); - newColumn.setTable(table); - columnsToBeAdded.add(newColumn); - return columnType; - } - - public String getTableName() { - return table.getName(); - } - - public MongoDBTable getGeneratedTable() { - return table; - } -} diff --git a/src/sqlancer/mongodb/query/MongoDBCreateIndexQuery.java b/src/sqlancer/mongodb/query/MongoDBCreateIndexQuery.java deleted file mode 100644 index c873b5924..000000000 --- a/src/sqlancer/mongodb/query/MongoDBCreateIndexQuery.java +++ /dev/null @@ -1,77 +0,0 @@ -package sqlancer.mongodb.query; - -import java.util.ArrayList; -import java.util.List; - -import org.bson.conversions.Bson; - -import com.mongodb.client.model.Indexes; - -import sqlancer.GlobalState; -import sqlancer.Main; -import sqlancer.common.query.ExpectedErrors; -import sqlancer.mongodb.MongoDBConnection; -import sqlancer.mongodb.MongoDBQueryAdapter; -import sqlancer.mongodb.MongoDBSchema.MongoDBTable; - -public class MongoDBCreateIndexQuery extends MongoDBQueryAdapter { - - private final MongoDBTable table; - private final List indeces; - private final List logIndeces; - - public MongoDBCreateIndexQuery(MongoDBTable table) { - this.table = table; - this.indeces = new ArrayList<>(); - this.logIndeces = new ArrayList<>(); - } - - public void addIndex(String column, boolean ascending) { - if (ascending) { - indeces.add(Indexes.ascending(column)); - logIndeces.add(column + ": 1"); - } else { - indeces.add(Indexes.descending(column)); - logIndeces.add(column + ": -1"); - } - } - - @Override - public String getLogString() { - StringBuilder sb = new StringBuilder(); - sb.append("db.").append(table.getName()).append(".createIndex({"); - String helper = ""; - for (String index : logIndeces) { - sb.append(helper); - helper = ","; - sb.append(index); - } - sb.append("})\n"); - return sb.toString(); - } - - @Override - public boolean couldAffectSchema() { - return false; - } - - @Override - public > boolean execute(G globalState, String... fills) - throws Exception { - Main.nrSuccessfulActions.addAndGet(1); - Bson index; - if (indeces.size() > 1) { - index = Indexes.compoundIndex(indeces); - } else { - index = indeces.get(0); - } - globalState.getConnection().getDatabase().getCollection(table.getName()).createIndex(index); - return true; - } - - @Override - public ExpectedErrors getExpectedErrors() { - return new ExpectedErrors(); - } - -} diff --git a/src/sqlancer/mongodb/query/MongoDBCreateTableQuery.java b/src/sqlancer/mongodb/query/MongoDBCreateTableQuery.java deleted file mode 100644 index 7bc174c77..000000000 --- a/src/sqlancer/mongodb/query/MongoDBCreateTableQuery.java +++ /dev/null @@ -1,115 +0,0 @@ -package sqlancer.mongodb.query; - -import java.util.ArrayList; -import java.util.List; - -import org.bson.BsonType; -import org.bson.conversions.Bson; - -import com.mongodb.client.model.CreateCollectionOptions; -import com.mongodb.client.model.Filters; -import com.mongodb.client.model.ValidationOptions; - -import sqlancer.GlobalState; -import sqlancer.Main; -import sqlancer.common.query.ExpectedErrors; -import sqlancer.mongodb.MongoDBConnection; -import sqlancer.mongodb.MongoDBQueryAdapter; - -public class MongoDBCreateTableQuery extends MongoDBQueryAdapter { - - private final String tableName; - private Bson validationFilter; - private final List logRequiredList; - private final List logPropertiesList; - - public MongoDBCreateTableQuery(String tableName) { - this.tableName = tableName; - this.validationFilter = null; - logRequiredList = new ArrayList<>(); - logPropertiesList = new ArrayList<>(); - } - - @Override - public boolean couldAffectSchema() { - return true; - } - - @Override - public > boolean execute(G globalState, String... fills) - throws Exception { - ValidationOptions collOptions = new ValidationOptions().validator(this.validationFilter); - Main.nrSuccessfulActions.addAndGet(1); - globalState.getConnection().getDatabase().createCollection(tableName, - new CreateCollectionOptions().validationOptions(collOptions)); - return true; - } - - @Override - public ExpectedErrors getExpectedErrors() { - return new ExpectedErrors(); - } - - @Override - public String getLogString() { - String helper = ""; - StringBuilder sb = new StringBuilder(); - sb.append("db.createCollection(\"").append(tableName).append("\", {\n"); - - if (!logPropertiesList.isEmpty()) { - sb.append("validator: {"); - sb.append("$jsonSchema: {"); - sb.append("bsonType:\"object\","); - sb.append("required: [\n"); - for (String req : logRequiredList) { - sb.append(helper); - helper = ","; - sb.append(req); - } - sb.append("],"); - sb.append("properties: {\n"); - for (String prop : logPropertiesList) { - sb.append(prop); - } - sb.append("}}}})"); - } else { - sb.append("})"); - } - - return sb.toString(); - } - - public void addValidation(String columnName, BsonType type) { - Bson nameFilter = Filters.exists(columnName); - Bson typeFilter = Filters.type(columnName, type); - - if (validationFilter == null) { - validationFilter = Filters.and(nameFilter, typeFilter); - } else { - validationFilter = Filters.and(validationFilter, Filters.and(nameFilter, typeFilter)); - } - - logRequiredList.add("\"" + columnName + "\""); - logPropertiesList.add(columnName + ": { bsonType:\"" + bsonTypeToString(type) + "\"},\n"); - } - - public String bsonTypeToString(BsonType type) { - switch (type) { - case DOUBLE: - return "double"; - case STRING: - return "string"; - case BOOLEAN: - return "bool"; - case INT32: - case INT64: - return "int"; - case DATE_TIME: - return "date"; - case TIMESTAMP: - return "timestamp"; - default: - throw new IllegalStateException(); - } - } -} diff --git a/src/sqlancer/mongodb/query/MongoDBInsertQuery.java b/src/sqlancer/mongodb/query/MongoDBInsertQuery.java deleted file mode 100644 index 127dc82d0..000000000 --- a/src/sqlancer/mongodb/query/MongoDBInsertQuery.java +++ /dev/null @@ -1,87 +0,0 @@ -package sqlancer.mongodb.query; - -import org.bson.BsonDateTime; -import org.bson.BsonTimestamp; -import org.bson.Document; -import org.bson.types.ObjectId; - -import com.mongodb.client.result.InsertOneResult; - -import sqlancer.GlobalState; -import sqlancer.Main; -import sqlancer.common.query.ExpectedErrors; -import sqlancer.mongodb.MongoDBConnection; -import sqlancer.mongodb.MongoDBQueryAdapter; -import sqlancer.mongodb.MongoDBSchema.MongoDBTable; - -public class MongoDBInsertQuery extends MongoDBQueryAdapter { - boolean excluded; - private final MongoDBTable table; - private final Document documentToBeInserted; - - public MongoDBInsertQuery(MongoDBTable table, Document documentToBeInserted) { - this.table = table; - this.documentToBeInserted = documentToBeInserted; - this.excluded = false; - } - - @Override - public String getLogString() { - StringBuilder sb = new StringBuilder(); - sb.append("db." + table.getName() + ".insert({"); - String helper = ""; - for (String key : documentToBeInserted.keySet()) { - sb.append(helper); - helper = ", "; - if (documentToBeInserted.get(key) instanceof ObjectId) { - continue; - } - Object value = documentToBeInserted.get(key); - sb.append(key); - sb.append(": "); - sb.append(getStringRepresentation(value)); - } - sb.append("})\n"); - - return sb.toString(); - } - - private String getStringRepresentation(Object value) { - if (value instanceof Double) { - return String.valueOf(value); - } else if (value instanceof Integer) { - return "NumberInt(" + value + ")"; - } else if (value instanceof String) { - return "\"" + value + "\""; - } else if (value instanceof BsonDateTime) { - return "new Date(" + ((BsonDateTime) value).getValue() + ")"; - } else if (value instanceof BsonTimestamp) { - return "Timestamp(" + ((BsonTimestamp) value).getValue() + ",1)"; - } else if (value instanceof Boolean) { - return String.valueOf(value); - } else if (value == null) { - return "null"; - } else { - throw new IllegalStateException(); - } - } - - @Override - public boolean couldAffectSchema() { - return true; - } - - @Override - public > boolean execute(G globalState, String... fills) - throws Exception { - Main.nrSuccessfulActions.addAndGet(1); - InsertOneResult result = globalState.getConnection().getDatabase().getCollection(table.getName()) - .insertOne(documentToBeInserted); - return result.wasAcknowledged(); - } - - @Override - public ExpectedErrors getExpectedErrors() { - return new ExpectedErrors(); - } -} diff --git a/src/sqlancer/mongodb/query/MongoDBRemoveQuery.java b/src/sqlancer/mongodb/query/MongoDBRemoveQuery.java deleted file mode 100644 index 6fe1c9e3f..000000000 --- a/src/sqlancer/mongodb/query/MongoDBRemoveQuery.java +++ /dev/null @@ -1,59 +0,0 @@ -package sqlancer.mongodb.query; - -import org.bson.Document; -import org.bson.types.ObjectId; - -import com.mongodb.client.result.DeleteResult; - -import sqlancer.GlobalState; -import sqlancer.Main; -import sqlancer.common.query.ExpectedErrors; -import sqlancer.mongodb.MongoDBConnection; -import sqlancer.mongodb.MongoDBQueryAdapter; -import sqlancer.mongodb.MongoDBSchema; - -public class MongoDBRemoveQuery extends MongoDBQueryAdapter { - - private final String objectId; - private final MongoDBSchema.MongoDBTable table; - - public MongoDBRemoveQuery(MongoDBSchema.MongoDBTable table, String objectId) { - this.objectId = objectId; - this.table = table; - } - - @Override - public boolean couldAffectSchema() { - return true; - } - - @Override - public > boolean execute(G globalState, String... fills) - throws Exception { - try { - DeleteResult result = globalState.getConnection().getDatabase().getCollection(table.getName()) - .deleteOne(new Document("_id", new ObjectId(objectId))); - if (result.wasAcknowledged()) { - Main.nrSuccessfulActions.addAndGet(1); - } else { - Main.nrUnsuccessfulActions.addAndGet(1); - } - return result.wasAcknowledged(); - } catch (Exception e) { - Main.nrUnsuccessfulActions.addAndGet(1); - return false; - } - } - - @Override - public ExpectedErrors getExpectedErrors() { - return new ExpectedErrors(); - } - - @Override - public String getLogString() { - StringBuilder stringBuilder = new StringBuilder(); - stringBuilder.append("db.").append(table.getName()).append(".remove({'_id': '").append(objectId).append("'})"); - return stringBuilder.toString(); - } -} diff --git a/src/sqlancer/mongodb/query/MongoDBSelectQuery.java b/src/sqlancer/mongodb/query/MongoDBSelectQuery.java deleted file mode 100644 index 1288f114c..000000000 --- a/src/sqlancer/mongodb/query/MongoDBSelectQuery.java +++ /dev/null @@ -1,147 +0,0 @@ -package sqlancer.mongodb.query; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - -import org.bson.Document; -import org.bson.conversions.Bson; - -import com.mongodb.client.MongoCollection; -import com.mongodb.client.MongoCursor; - -import sqlancer.GlobalState; -import sqlancer.common.query.ExpectedErrors; -import sqlancer.common.query.SQLancerResultSet; -import sqlancer.mongodb.MongoDBConnection; -import sqlancer.mongodb.MongoDBQueryAdapter; -import sqlancer.mongodb.ast.MongoDBExpression; -import sqlancer.mongodb.ast.MongoDBSelect; -import sqlancer.mongodb.visitor.MongoDBVisitor; - -public class MongoDBSelectQuery extends MongoDBQueryAdapter { - - private final MongoDBSelect select; - - private List resultSet; - - public MongoDBSelectQuery(MongoDBSelect select) { - this.select = select; - } - - @Override - public boolean couldAffectSchema() { - return false; - } - - @Override - public > boolean execute(G globalState, String... fills) - throws Exception { - throw new UnsupportedOperationException(); - } - - @Override - public ExpectedErrors getExpectedErrors() { - ExpectedErrors errors = new ExpectedErrors(); - // ARITHMETIC - errors.add("Failed to optimize pipeline :: caused by :: Can't coerce out of range value"); - errors.add("Can't coerce out of range value"); - errors.add("date overflow in $add"); - errors.add("Failed to optimize pipeline :: caused by :: $sqrt only supports numeric types, not"); - errors.add("Failed to optimize pipeline :: caused by :: $sqrt's argument must be greater than or equal to 0"); - errors.add("Failed to optimize pipeline :: caused by :: $pow's base must be numeric, not"); - errors.add("Failed to optimize pipeline :: caused by :: $pow cannot take a base of 0 and a negative exponent"); - errors.add("Failed to optimize pipeline :: caused by :: $add only supports numeric or date types, not"); - errors.add("Failed to optimize pipeline :: caused by :: $exp only supports numeric types, not"); - errors.add("Failed to optimize pipeline :: caused by :: $log's base must be numeric, not"); - errors.add("Failed to optimize pipeline :: caused by :: $log's base must be a positive number not equal to 1"); - errors.add("Failed to optimize pipeline :: caused by :: $multiply only supports numeric types, not"); - errors.add("$log's argument must be numeric, not"); - errors.add("$log's argument must be a positive number, but"); - errors.add("$log's base must be numeric, not"); - errors.add("$log's base must be a positive number not equal to 1"); - errors.add("$divide only supports numeric types, not"); - errors.add("can't $divide by zero"); - errors.add("$pow's exponent must be numeric, not"); - errors.add("$pow's base must be numeric, not"); - errors.add("$pow cannot take a base of 0 and a negative exponent"); - errors.add("$add only supports numeric or date types, not"); - errors.add("only one date allowed in an $add expression"); - errors.add("$multiply only supports numeric types, not"); - errors.add("$exp only supports numeric types, not"); - errors.add("$sqrt's argument must be greater than or equal to 0"); - errors.add("$sqrt only supports numeric types, not"); - - // REGEX - errors.add("Regular expression is invalid: nothing to repeat"); - errors.add("Regular expression is invalid: missing terminating ] for character class"); - errors.add("Regular expression is invalid: unmatched parentheses"); - errors.add("Regular expression is invalid: missing )"); - errors.add("Regular expression is invalid: invalid UTF-8 string"); - errors.add("Regular expression is invalid: \\k is not followed by a braced, angle-bracketed, or quoted name"); - errors.add("Regular expression is invalid: missing opening brace after \\\\o"); - errors.add("Regular expression is invalid: reference to non-existent subpattern"); - errors.add("Regular expression is invalid: \\ at end of pattern"); - errors.add("Regular expression is invalid: PCRE does not support \\L, \\l, \\N{name}, \\U, or \\u"); - errors.add("Regular expression is invalid: (?R or (?[+-]digits must be followed by )"); - errors.add("Regular expression is invalid: unknown property name after \\P or \\p"); - errors.add("Regular expression is invalid: (*VERB) not recognized or malformed"); - errors.add("Regular expression is invalid: a numbered reference must not be zero"); - errors.add("Regular expression is invalid: unrecognized character after (? or (?-"); - errors.add("Regular expression is invalid: \\c at end of pattern"); - errors.add("Regular expression is invalid: malformed \\P or \\p sequence"); - errors.add("Regular expression is invalid: range out of order in character class"); - errors.add("Regular expression is invalid: group name must start with a non-digit"); - errors.add("Regular expression is invalid: \\c must be followed by an ASCII character"); - errors.add("Regular expression is invalid: subpattern name expected"); - errors.add("Regular expression is invalid: POSIX collating elements are not supported"); - errors.add("Regular expression is invalid: closing ) for (?C expected"); - errors.add("Regular expression is invalid: syntax error in subpattern name (missing terminator)"); - errors.add("Regular expression is invalid: \\\\N is not supported in a class"); - errors.add("Regular expression is invalid: non-octal character in \\o{} (closing brace missing?)"); - errors.add("Regular expression is invalid: non-hex character in \\x{} (closing brace missing?)"); - errors.add( - "Regular expression is invalid: \\g is not followed by a braced, angle-bracketed, or quoted name/number or by a plain number"); - errors.add("Regular expression is invalid: digits missing in \\x{} or \\o{}"); - errors.add("Regular expression is invalid: malformed number or name after (?("); - errors.add("Regular expression is invalid: digit expected after (?+"); - errors.add("Regular expression is invalid: assertion expected after (?( or (?(?C)"); - errors.add("Regular expression is invalid: unrecognized character after (?P"); - - return errors; - } - - @Override - public > SQLancerResultSet executeAndGet(G globalState, - String... fills) throws Exception { - if (globalState.getOptions().logEachSelect()) { - globalState.getLogger().writeCurrent(this.getLogString()); - try { - globalState.getLogger().getCurrentFileWriter().flush(); - } catch (IOException e) { - e.printStackTrace(); - } - } - List pipeline = MongoDBVisitor.asQuery(select); - - MongoCollection collection = globalState.getConnection().getDatabase() - .getCollection(select.getMainTableName()); - MongoCursor cursor = collection.aggregate(pipeline).cursor(); - resultSet = new ArrayList<>(); - while (cursor.hasNext()) { - Document document = cursor.next(); - resultSet.add(document); - } - return null; - } - - @Override - public String getLogString() { - return MongoDBVisitor.asStringLog(select); - } - - public List getResultSet() { - return resultSet; - } - -} diff --git a/src/sqlancer/mongodb/test/MongoDBColumnTestReference.java b/src/sqlancer/mongodb/test/MongoDBColumnTestReference.java deleted file mode 100644 index 59a2a6724..000000000 --- a/src/sqlancer/mongodb/test/MongoDBColumnTestReference.java +++ /dev/null @@ -1,40 +0,0 @@ -package sqlancer.mongodb.test; - -import sqlancer.common.ast.newast.Node; -import sqlancer.mongodb.MongoDBSchema.MongoDBColumn; -import sqlancer.mongodb.ast.MongoDBExpression; - -public class MongoDBColumnTestReference implements Node { - - private final MongoDBColumn columnReference; - private final boolean inMainTable; - - public MongoDBColumnTestReference(MongoDBColumn columnReference, boolean inMainTable) { - this.columnReference = columnReference; - this.inMainTable = inMainTable; - } - - public String getQueryString() { - if (inMainTable) { - return this.columnReference.getName(); - } else { - return "join_" + this.columnReference.getTable().getName() + "." + this.columnReference.getName(); - } - } - - public boolean inMainTable() { - return inMainTable; - } - - public String getTableName() { - return this.columnReference.getTable().getName(); - } - - public String getPlainName() { - return this.columnReference.getName(); - } - - public MongoDBColumn getColumnReference() { - return columnReference; - } -} diff --git a/src/sqlancer/mongodb/test/MongoDBDocumentRemovalBase.java b/src/sqlancer/mongodb/test/MongoDBDocumentRemovalBase.java deleted file mode 100644 index 6faae42bc..000000000 --- a/src/sqlancer/mongodb/test/MongoDBDocumentRemovalBase.java +++ /dev/null @@ -1,89 +0,0 @@ -package sqlancer.mongodb.test; - -import java.util.ArrayList; -import java.util.List; - -import sqlancer.Randomly; -import sqlancer.common.ast.newast.Node; -import sqlancer.common.gen.ExpressionGenerator; -import sqlancer.common.oracle.DocumentRemovalOracleBase; -import sqlancer.common.oracle.TestOracle; -import sqlancer.mongodb.MongoDBProvider; -import sqlancer.mongodb.MongoDBSchema; -import sqlancer.mongodb.ast.MongoDBExpression; -import sqlancer.mongodb.ast.MongoDBSelect; -import sqlancer.mongodb.gen.MongoDBComputedExpressionGenerator; -import sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator; - -public class MongoDBDocumentRemovalBase extends - DocumentRemovalOracleBase, MongoDBProvider.MongoDBGlobalState> implements TestOracle { - - protected MongoDBSchema schema; - protected MongoDBSchema.MongoDBTables targetTables; - protected MongoDBSchema.MongoDBTable mainTable; - protected List targetColumns; - protected MongoDBMatchExpressionGenerator expressionGenerator; - protected MongoDBSelect select; - - protected MongoDBDocumentRemovalBase(MongoDBProvider.MongoDBGlobalState state) { - super(state); - } - - @Override - public void check() throws Exception { - schema = state.getSchema(); - targetTables = schema.getRandomTableNonEmptyTables(); - mainTable = targetTables.getTables().get(0); - generateTargetColumns(); - expressionGenerator = new MongoDBMatchExpressionGenerator(state).setColumns(targetColumns); - initializeDocumentRemovalOracle(); - select = new MongoDBSelect<>(mainTable.getName(), targetColumns.get(0)); - select.setProjectionList(targetColumns); - if (Randomly.getBooleanWithRatherLowProbability()) { - select.setLookupList(targetColumns); - } else { - select.setLookupList(Randomly.nonEmptySubset(targetColumns)); - } - if (state.getDbmsSpecificOptions().testComputedValues) { - generateComputedColumns(); - } - } - - private void generateTargetColumns() { - targetColumns = new ArrayList<>(); - for (MongoDBSchema.MongoDBColumn c : mainTable.getColumns()) { - targetColumns.add(new MongoDBColumnTestReference(c, true)); - } - List joinsOtherTables = new ArrayList<>(); - if (!state.getDbmsSpecificOptions().nullSafety) { - for (int i = 1; i < targetTables.getTables().size(); i++) { - MongoDBSchema.MongoDBTable procTable = targetTables.getTables().get(i); - for (MongoDBSchema.MongoDBColumn c : procTable.getColumns()) { - joinsOtherTables.add(new MongoDBColumnTestReference(c, false)); - } - } - } - if (!joinsOtherTables.isEmpty()) { - int randNumber = state.getRandomly().getInteger(1, Math.min(joinsOtherTables.size(), 4)); - List subsetJoinsOtherTables = Randomly.nonEmptySubset(joinsOtherTables, - randNumber); - targetColumns.addAll(subsetJoinsOtherTables); - } - } - - private void generateComputedColumns() { - List> computedColumns = new ArrayList<>(); - int numberComputedColumns = state.getRandomly().getInteger(1, 4); - MongoDBComputedExpressionGenerator generator = new MongoDBComputedExpressionGenerator(state) - .setColumns(targetColumns); - for (int i = 0; i < numberComputedColumns; i++) { - computedColumns.add(generator.generateExpression()); - } - select.setComputedClause(computedColumns); - } - - @Override - protected ExpressionGenerator> getGen() { - return expressionGenerator; - } -} diff --git a/src/sqlancer/mongodb/test/MongoDBDocumentRemovalTester.java b/src/sqlancer/mongodb/test/MongoDBDocumentRemovalTester.java deleted file mode 100644 index ece3193d1..000000000 --- a/src/sqlancer/mongodb/test/MongoDBDocumentRemovalTester.java +++ /dev/null @@ -1,49 +0,0 @@ -package sqlancer.mongodb.test; - -import static sqlancer.mongodb.MongoDBComparatorHelper.getResultSetAsDocumentList; - -import java.util.List; - -import org.bson.Document; - -import sqlancer.Randomly; -import sqlancer.mongodb.MongoDBProvider; -import sqlancer.mongodb.MongoDBQueryAdapter; -import sqlancer.mongodb.gen.MongoDBInsertGenerator; -import sqlancer.mongodb.query.MongoDBRemoveQuery; -import sqlancer.mongodb.query.MongoDBSelectQuery; - -public class MongoDBDocumentRemovalTester extends MongoDBDocumentRemovalBase { - public MongoDBDocumentRemovalTester(MongoDBProvider.MongoDBGlobalState state) { - super(state); - } - - @Override - public void check() throws Exception { - super.check(); - - select.setWithCountClause(false); - - select.setFilterClause(predicate); - MongoDBSelectQuery selectQuery = new MongoDBSelectQuery(select); - List firstResultSet = getResultSetAsDocumentList(selectQuery, state); - if (firstResultSet == null || firstResultSet.isEmpty()) { - return; - } - - Document documentToRemove = Randomly.fromList(firstResultSet); - MongoDBRemoveQuery removeQuery = new MongoDBRemoveQuery(mainTable, documentToRemove.get("_id").toString()); - state.executeStatement(removeQuery); - - selectQuery = new MongoDBSelectQuery(select); - List secondResultSet = getResultSetAsDocumentList(selectQuery, state); - - MongoDBQueryAdapter insertQuery = MongoDBInsertGenerator.getQuery(state); - state.executeStatement(insertQuery); - - if (secondResultSet.size() + 1 != firstResultSet.size()) { - String assertMessage = "The Result Sizes mismatches!"; - throw new AssertionError(assertMessage); - } - } -} diff --git a/src/sqlancer/mongodb/test/MongoDBQueryPartitioningBase.java b/src/sqlancer/mongodb/test/MongoDBQueryPartitioningBase.java deleted file mode 100644 index 30ebe56bd..000000000 --- a/src/sqlancer/mongodb/test/MongoDBQueryPartitioningBase.java +++ /dev/null @@ -1,92 +0,0 @@ -package sqlancer.mongodb.test; - -import java.util.ArrayList; -import java.util.List; - -import sqlancer.Randomly; -import sqlancer.common.ast.newast.Node; -import sqlancer.common.gen.ExpressionGenerator; -import sqlancer.common.oracle.TernaryLogicPartitioningOracleBase; -import sqlancer.common.oracle.TestOracle; -import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState; -import sqlancer.mongodb.MongoDBSchema; -import sqlancer.mongodb.MongoDBSchema.MongoDBColumn; -import sqlancer.mongodb.MongoDBSchema.MongoDBTable; -import sqlancer.mongodb.MongoDBSchema.MongoDBTables; -import sqlancer.mongodb.ast.MongoDBExpression; -import sqlancer.mongodb.ast.MongoDBSelect; -import sqlancer.mongodb.gen.MongoDBComputedExpressionGenerator; -import sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator; - -public class MongoDBQueryPartitioningBase - extends TernaryLogicPartitioningOracleBase, MongoDBGlobalState> implements TestOracle { - - protected MongoDBSchema schema; - protected MongoDBTables targetTables; - protected MongoDBTable mainTable; - protected List targetColumns; - protected MongoDBMatchExpressionGenerator expressionGenerator; - protected MongoDBSelect select; - - public MongoDBQueryPartitioningBase(MongoDBGlobalState state) { - super(state); - } - - @Override - public void check() throws Exception { - schema = state.getSchema(); - targetTables = schema.getRandomTableNonEmptyTables(); - mainTable = targetTables.getTables().get(0); - generateTargetColumns(); - expressionGenerator = new MongoDBMatchExpressionGenerator(state).setColumns(targetColumns); - initializeTernaryPredicateVariants(); - select = new MongoDBSelect<>(mainTable.getName(), targetColumns.get(0)); - select.setProjectionList(targetColumns); - if (Randomly.getBooleanWithRatherLowProbability()) { - select.setLookupList(targetColumns); - } else { - select.setLookupList(Randomly.nonEmptySubset(targetColumns)); - } - if (state.getDbmsSpecificOptions().testComputedValues) { - generateComputedColumns(); - } - } - - private void generateComputedColumns() { - List> computedColumns = new ArrayList<>(); - int numberComputedColumns = state.getRandomly().getInteger(1, 4); - MongoDBComputedExpressionGenerator generator = new MongoDBComputedExpressionGenerator(state) - .setColumns(targetColumns); - for (int i = 0; i < numberComputedColumns; i++) { - computedColumns.add(generator.generateExpression()); - } - select.setComputedClause(computedColumns); - } - - private void generateTargetColumns() { - targetColumns = new ArrayList<>(); - for (MongoDBColumn c : mainTable.getColumns()) { - targetColumns.add(new MongoDBColumnTestReference(c, true)); - } - List joinsOtherTables = new ArrayList<>(); - if (!state.getDbmsSpecificOptions().nullSafety) { - for (int i = 1; i < targetTables.getTables().size(); i++) { - MongoDBTable procTable = targetTables.getTables().get(i); - for (MongoDBColumn c : procTable.getColumns()) { - joinsOtherTables.add(new MongoDBColumnTestReference(c, false)); - } - } - } - if (!joinsOtherTables.isEmpty()) { - int randNumber = state.getRandomly().getInteger(1, Math.min(joinsOtherTables.size(), 4)); - List subsetJoinsOtherTables = Randomly.nonEmptySubset(joinsOtherTables, - randNumber); - targetColumns.addAll(subsetJoinsOtherTables); - } - } - - @Override - protected ExpressionGenerator> getGen() { - return expressionGenerator; - } -} diff --git a/src/sqlancer/mongodb/test/MongoDBQueryPartitioningWhereTester.java b/src/sqlancer/mongodb/test/MongoDBQueryPartitioningWhereTester.java deleted file mode 100644 index 53f625d9b..000000000 --- a/src/sqlancer/mongodb/test/MongoDBQueryPartitioningWhereTester.java +++ /dev/null @@ -1,48 +0,0 @@ -package sqlancer.mongodb.test; - -import static sqlancer.mongodb.MongoDBComparatorHelper.getResultSetAsDocumentList; - -import java.util.List; - -import org.bson.Document; - -import sqlancer.mongodb.MongoDBComparatorHelper; -import sqlancer.mongodb.MongoDBProvider.MongoDBGlobalState; -import sqlancer.mongodb.query.MongoDBSelectQuery; - -public class MongoDBQueryPartitioningWhereTester extends MongoDBQueryPartitioningBase { - public MongoDBQueryPartitioningWhereTester(MongoDBGlobalState state) { - super(state); - } - - @Override - public void check() throws Exception { - super.check(); - - select.setWithCountClause(false); - - select.setFilterClause(null); - MongoDBSelectQuery q = new MongoDBSelectQuery(select); - List firstResultSet = getResultSetAsDocumentList(q, state); - - select.setFilterClause(predicate); - q = new MongoDBSelectQuery(select); - List secondResultSet = getResultSetAsDocumentList(q, state); - - select.setFilterClause(negatedPredicate); - q = new MongoDBSelectQuery(select); - List thirdResultSet = getResultSetAsDocumentList(q, state); - - if (state.getDbmsSpecificOptions().testWithCount) { - select.setWithCountClause(true); - select.setFilterClause(predicate); - q = new MongoDBSelectQuery(select); - List forthResultSet = getResultSetAsDocumentList(q, state); - MongoDBComparatorHelper.assumeCountIsEqual(secondResultSet, forthResultSet, q); - } - - secondResultSet.addAll(thirdResultSet); - MongoDBComparatorHelper.assumeResultSetsAreEqual(firstResultSet, secondResultSet, q); - - } -} diff --git a/src/sqlancer/mongodb/visitor/MongoDBNegateVisitor.java b/src/sqlancer/mongodb/visitor/MongoDBNegateVisitor.java deleted file mode 100644 index 39b607f67..000000000 --- a/src/sqlancer/mongodb/visitor/MongoDBNegateVisitor.java +++ /dev/null @@ -1,161 +0,0 @@ -package sqlancer.mongodb.visitor; - -import static sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator.MongoDBBinaryLogicalOperator.AND; -import static sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator.MongoDBBinaryLogicalOperator.NOR; -import static sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator.MongoDBBinaryLogicalOperator.OR; -import static sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator.MongoDBUnaryLogicalOperator.NOT; - -import sqlancer.common.ast.newast.Node; -import sqlancer.mongodb.ast.MongoDBBinaryComparisonNode; -import sqlancer.mongodb.ast.MongoDBBinaryLogicalNode; -import sqlancer.mongodb.ast.MongoDBConstant; -import sqlancer.mongodb.ast.MongoDBExpression; -import sqlancer.mongodb.ast.MongoDBRegexNode; -import sqlancer.mongodb.ast.MongoDBSelect; -import sqlancer.mongodb.ast.MongoDBUnaryLogicalOperatorNode; -import sqlancer.mongodb.gen.MongoDBMatchExpressionGenerator; - -public class MongoDBNegateVisitor extends MongoDBVisitor { - - private boolean negate; - Node negatedExpression; - - public MongoDBNegateVisitor(boolean negate) { - this.negate = negate; - } - - @Override - public void visit(Node expr) { - if (expr instanceof MongoDBConstant) { - visit((MongoDBConstant) expr); - } else if (expr instanceof MongoDBSelect) { - visit((MongoDBSelect) expr); - } else if (expr instanceof MongoDBBinaryComparisonNode) { - visit((MongoDBBinaryComparisonNode) expr); - } else if (expr instanceof MongoDBUnaryLogicalOperatorNode) { - visit((MongoDBUnaryLogicalOperatorNode) expr); - } else if (expr instanceof MongoDBRegexNode) { - visit((MongoDBRegexNode) expr); - } else if (expr instanceof MongoDBBinaryLogicalNode) { - visit((MongoDBBinaryLogicalNode) expr); - } else { - throw new AssertionError(expr.getClass()); - } - } - - public void visit(MongoDBBinaryComparisonNode expr) { - - if (negate) { - negatedExpression = new MongoDBUnaryLogicalOperatorNode(expr, NOT); - switch (expr.operator()) { - case EQUALS: - negatedExpression = new MongoDBBinaryComparisonNode(expr.getLeft(), expr.getRight(), - MongoDBMatchExpressionGenerator.MongoDBBinaryComparisonOperator.NOT_EQUALS); - break; - case NOT_EQUALS: - negatedExpression = new MongoDBBinaryComparisonNode(expr.getLeft(), expr.getRight(), - MongoDBMatchExpressionGenerator.MongoDBBinaryComparisonOperator.EQUALS); - break; - case LESS: - negatedExpression = new MongoDBBinaryComparisonNode(expr.getLeft(), expr.getRight(), - MongoDBMatchExpressionGenerator.MongoDBBinaryComparisonOperator.GREATER_EQUAL); - break; - case LESS_EQUAL: - negatedExpression = new MongoDBBinaryComparisonNode(expr.getLeft(), expr.getRight(), - MongoDBMatchExpressionGenerator.MongoDBBinaryComparisonOperator.GREATER); - break; - case GREATER: - negatedExpression = new MongoDBBinaryComparisonNode(expr.getLeft(), expr.getRight(), - MongoDBMatchExpressionGenerator.MongoDBBinaryComparisonOperator.LESS_EQUAL); - break; - case GREATER_EQUAL: - negatedExpression = new MongoDBBinaryComparisonNode(expr.getLeft(), expr.getRight(), - MongoDBMatchExpressionGenerator.MongoDBBinaryComparisonOperator.LESS); - break; - default: - throw new UnsupportedOperationException(); - } - } else { - negatedExpression = expr; - } - } - - public void visit(MongoDBRegexNode expr) { - if (negate) { - negatedExpression = new MongoDBUnaryLogicalOperatorNode(expr, NOT); - } else { - negatedExpression = expr; - } - } - - public void visit(MongoDBUnaryLogicalOperatorNode expr) { - if (!(expr.operator().equals(NOT))) { - throw new UnsupportedOperationException(); - } - negate = !negate; - visit(expr.getExpr()); - } - - public void visit(MongoDBBinaryLogicalNode expr) { - boolean saveNegate = negate; - Node left; - Node right; - switch (expr.operator()) { - case OR: - negate = false; - visit(expr.getLeft()); - left = negatedExpression; - negate = false; - visit(expr.getRight()); - right = negatedExpression; - if (saveNegate) { - negatedExpression = new MongoDBBinaryLogicalNode(left, right, NOR); - } else { - negatedExpression = new MongoDBBinaryLogicalNode(left, right, OR); - } - break; - case AND: - negate = saveNegate; - visit(expr.getLeft()); - left = negatedExpression; - negate = saveNegate; - visit(expr.getRight()); - right = negatedExpression; - if (saveNegate) { - negatedExpression = new MongoDBBinaryLogicalNode(left, right, OR); - } else { - negatedExpression = new MongoDBBinaryLogicalNode(left, right, AND); - } - break; - case NOR: - negate = false; - visit(expr.getLeft()); - left = negatedExpression; - negate = false; - visit(expr.getRight()); - right = negatedExpression; - if (saveNegate) { - negatedExpression = new MongoDBBinaryLogicalNode(left, right, OR); - } else { - negatedExpression = new MongoDBBinaryLogicalNode(left, right, NOR); - } - break; - default: - throw new UnsupportedOperationException(expr.getOperatorRepresentation()); - } - } - - @Override - public void visit(MongoDBConstant c) { - negatedExpression = c; - } - - @Override - public void visit(MongoDBSelect s) { - throw new UnsupportedOperationException(); - } - - public Node getNegatedExpression() { - return negatedExpression; - } -} diff --git a/src/sqlancer/mongodb/visitor/MongoDBToLogVisitor.java b/src/sqlancer/mongodb/visitor/MongoDBToLogVisitor.java deleted file mode 100644 index 4c55e17a6..000000000 --- a/src/sqlancer/mongodb/visitor/MongoDBToLogVisitor.java +++ /dev/null @@ -1,195 +0,0 @@ -package sqlancer.mongodb.visitor; - -import java.util.ArrayList; -import java.util.List; - -import sqlancer.common.ast.newast.NewFunctionNode; -import sqlancer.common.ast.newast.Node; -import sqlancer.mongodb.ast.MongoDBBinaryComparisonNode; -import sqlancer.mongodb.ast.MongoDBBinaryLogicalNode; -import sqlancer.mongodb.ast.MongoDBConstant; -import sqlancer.mongodb.ast.MongoDBExpression; -import sqlancer.mongodb.ast.MongoDBRegexNode; -import sqlancer.mongodb.ast.MongoDBSelect; -import sqlancer.mongodb.ast.MongoDBUnaryLogicalOperatorNode; -import sqlancer.mongodb.gen.MongoDBComputedExpressionGenerator.ComputedFunction; -import sqlancer.mongodb.test.MongoDBColumnTestReference; - -public class MongoDBToLogVisitor extends MongoDBVisitor { - - private String mainTableName; - private List lookups; - private String filter; - private String projects; - private boolean hasFilter; - private boolean withCount; - - public String visitLog(Node expr) { - if (expr instanceof MongoDBUnaryLogicalOperatorNode) { - return visit((MongoDBUnaryLogicalOperatorNode) expr); - } else if (expr instanceof MongoDBBinaryLogicalNode) { - return visit((MongoDBBinaryLogicalNode) expr); - } else if (expr instanceof MongoDBBinaryComparisonNode) { - return visit((MongoDBBinaryComparisonNode) expr); - } else if (expr instanceof MongoDBRegexNode) { - return visit((MongoDBRegexNode) expr); - } else { - throw new AssertionError(expr.getClass()); - } - } - - public String visitComputed(Node expr) { - if (expr instanceof NewFunctionNode) { - return visitComputed((NewFunctionNode) expr); - } else { - throw new AssertionError(expr.getClass()); - } - } - - public String visitComputed(NewFunctionNode expr) { - List arguments = new ArrayList<>(); - for (int i = 0; i < expr.getArgs().size(); i++) { - if (expr.getArgs().get(i) instanceof MongoDBConstant) { - arguments.add(((MongoDBConstant) expr.getArgs().get(i)).getLogValue()); - continue; - } - if (expr.getArgs().get(i) instanceof MongoDBColumnTestReference) { - arguments.add("\"$" + ((MongoDBColumnTestReference) expr.getArgs().get(i)).getQueryString() + "\""); - continue; - } - if (expr.getArgs().get(i) instanceof NewFunctionNode) { - arguments.add(visitComputed((NewFunctionNode) expr.getArgs().get(i))); - } else { - throw new AssertionError(); - } - } - if (!(expr.getFunc() instanceof ComputedFunction)) { - throw new AssertionError(expr.getClass()); - } - - StringBuilder sb = new StringBuilder(); - sb.append("{"); - sb.append(((ComputedFunction) expr.getFunc()).getOperator()); - sb.append(": ["); - String helper = ""; - for (String arg : arguments) { - sb.append(helper); - helper = ", "; - sb.append(arg); - } - sb.append("]}"); - return sb.toString(); - } - - public String visit(MongoDBUnaryLogicalOperatorNode expr) { - String inner = visitLog(expr.getExpr()); - return "{ " + expr.operator().getTextRepresentation() + ": [" + inner + "]}"; - } - - public String visit(MongoDBBinaryLogicalNode expr) { - String left = visitLog(expr.getLeft()); - String right = visitLog(expr.getRight()); - - return "{" + expr.operator().getTextRepresentation() + ":[" + left + "," + right + "]}"; - } - - public String visit(MongoDBBinaryComparisonNode expr) { - Node left = expr.getLeft(); - Node right = expr.getRight(); - assert left instanceof MongoDBColumnTestReference; - assert right instanceof MongoDBConstant; - - return "{\"" + ((MongoDBColumnTestReference) left).getQueryString() + "\": {" - + expr.operator().getTextRepresentation() + ": " + ((MongoDBConstant) right).getLogValue() + "}}"; - } - - public String visit(MongoDBRegexNode expr) { - Node left = expr.getLeft(); - Node right = expr.getRight(); - - return "{\"" + ((MongoDBColumnTestReference) left).getQueryString() + "\": {" - + expr.operator().getTextRepresentation() + ": \'" - + ((MongoDBConstant.MongoDBStringConstant) right).getStringValue() + "\', $options: \'" - + expr.getOptions() + "\'}}"; - } - - @Override - public void visit(MongoDBConstant c) { - throw new UnsupportedOperationException(); - } - - @Override - public void visit(MongoDBSelect select) { - hasFilter = select.hasFilter(); - mainTableName = select.getMainTableName(); - setLookups(select); - if (hasFilter) { - setFilter(select); - } - setProjects(select); - withCount = select.getWithCountClause(); - } - - private void setFilter(MongoDBSelect select) { - filter = visitLog(select.getFilterClause()); - } - - private void setLookups(MongoDBSelect select) { - lookups = new ArrayList<>(); - for (MongoDBColumnTestReference testReference : select.getLookupList()) { - if (testReference.inMainTable()) { - continue; - } - String newLookup = "{ $lookup: { from: \"" + testReference.getTableName() + "\", localField: \"" - + select.getJoinColumn().getPlainName() + "\", foreignField: \"" + testReference.getPlainName() - + "\", as: \"" + testReference.getQueryString() + "\"}},\n"; - lookups.add(newLookup); - } - } - - private void setProjects(MongoDBSelect select) { - StringBuilder sb = new StringBuilder(); - sb.append("{"); - String helper = ""; - for (MongoDBColumnTestReference reference : select.getProjectionList()) { - sb.append(helper); - helper = ","; - sb.append("\"").append(reference.getQueryString()).append("\"").append(": 1"); - } - sb.append("\n"); - if (select.hasComputed()) { - String name = "computed"; - int number = 0; - for (Node expressionNode : select.getComputedClause()) { - sb.append(helper); - helper = ",\n"; - sb.append("\"" + name + number + "\": " + visitComputed(expressionNode)); - number++; - } - } - sb.append("}"); - projects = sb.toString(); - } - - public String getStringLog() { - StringBuilder sb = new StringBuilder(); - sb.append("db.").append(mainTableName).append(".aggregate([\n"); - for (String lookup : lookups) { - sb.append(lookup); - } - if (hasFilter) { - sb.append("{ $match: "); - sb.append(filter); - sb.append("},\n"); - } - sb.append("{ $project : "); - sb.append(projects); - sb.append("}"); - if (withCount) { - sb.append(",\n"); - sb.append(" {$count: \"count\"}\n"); - } - sb.append("])\n"); - return sb.toString(); - } -} diff --git a/src/sqlancer/mongodb/visitor/MongoDBToQueryVisitor.java b/src/sqlancer/mongodb/visitor/MongoDBToQueryVisitor.java deleted file mode 100644 index 8efbf97e4..000000000 --- a/src/sqlancer/mongodb/visitor/MongoDBToQueryVisitor.java +++ /dev/null @@ -1,184 +0,0 @@ -package sqlancer.mongodb.visitor; - -import static com.mongodb.client.model.Aggregates.match; -import static com.mongodb.client.model.Aggregates.project; -import static com.mongodb.client.model.Projections.fields; -import static com.mongodb.client.model.Projections.include; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; - -import org.bson.Document; -import org.bson.conversions.Bson; - -import com.mongodb.client.model.Aggregates; -import com.mongodb.client.model.Projections; - -import sqlancer.common.ast.newast.NewFunctionNode; -import sqlancer.common.ast.newast.Node; -import sqlancer.mongodb.ast.MongoDBBinaryComparisonNode; -import sqlancer.mongodb.ast.MongoDBBinaryLogicalNode; -import sqlancer.mongodb.ast.MongoDBConstant; -import sqlancer.mongodb.ast.MongoDBConstant.MongoDBStringConstant; -import sqlancer.mongodb.ast.MongoDBExpression; -import sqlancer.mongodb.ast.MongoDBRegexNode; -import sqlancer.mongodb.ast.MongoDBSelect; -import sqlancer.mongodb.ast.MongoDBUnaryLogicalOperatorNode; -import sqlancer.mongodb.gen.MongoDBComputedExpressionGenerator.ComputedFunction; -import sqlancer.mongodb.test.MongoDBColumnTestReference; - -public class MongoDBToQueryVisitor extends MongoDBVisitor { - - private List lookup; - private Bson filter; - private Bson projection; - private Bson count; - private boolean hasFilter; - private boolean hasCountClause; - - public Bson visitBson(Node expr) { - if (expr instanceof MongoDBUnaryLogicalOperatorNode) { - return visit((MongoDBUnaryLogicalOperatorNode) expr); - } else if (expr instanceof MongoDBBinaryLogicalNode) { - return visit((MongoDBBinaryLogicalNode) expr); - } else if (expr instanceof MongoDBBinaryComparisonNode) { - return visit((MongoDBBinaryComparisonNode) expr); - } else if (expr instanceof MongoDBRegexNode) { - return visit((MongoDBRegexNode) expr); - } else { - throw new AssertionError(expr.getClass()); - } - } - - public Document visitComputed(Node expr) { - if (expr instanceof NewFunctionNode) { - return visitComputed((NewFunctionNode) expr); - } else { - throw new AssertionError(expr.getClass()); - } - } - - public Document visitComputed(NewFunctionNode expr) { - List visitedArgs = new ArrayList<>(); - for (int i = 0; i < expr.getArgs().size(); i++) { - if (expr.getArgs().get(i) instanceof MongoDBConstant) { - visitedArgs.add(((MongoDBConstant) expr.getArgs().get(i)).getSerializedValue()); - continue; - } - if (expr.getArgs().get(i) instanceof MongoDBColumnTestReference) { - visitedArgs.add("$" + ((MongoDBColumnTestReference) expr.getArgs().get(i)).getQueryString()); - continue; - } - if (expr.getArgs().get(i) instanceof NewFunctionNode) { - visitedArgs.add(visitComputed((NewFunctionNode) expr.getArgs().get(i))); - } else { - throw new AssertionError(); - } - } - if (expr.getFunc() instanceof ComputedFunction) { - return new Document(((ComputedFunction) expr.getFunc()).getOperator(), visitedArgs); - } else { - throw new AssertionError(expr.getClass()); - } - - } - - public Bson visit(MongoDBUnaryLogicalOperatorNode expr) { - Bson inner = visitBson(expr.getExpr()); - return expr.operator().applyOperator(inner); - } - - public Bson visit(MongoDBBinaryLogicalNode expr) { - Bson left = visitBson(expr.getLeft()); - Bson right = visitBson(expr.getRight()); - return expr.operator().applyOperator(left, right); - } - - public Bson visit(MongoDBRegexNode expr) { - Node left = expr.getLeft(); - Node right = expr.getRight(); - - String columnName = ((MongoDBColumnTestReference) left).getQueryString(); - - return expr.operator().applyOperator(columnName, (MongoDBStringConstant) right, expr.getOptions()); - } - - public Bson visit(MongoDBBinaryComparisonNode expr) { - Node left = expr.getLeft(); - Node right = expr.getRight(); - assert left instanceof MongoDBColumnTestReference; - assert right instanceof MongoDBConstant; - - String columnName = ((MongoDBColumnTestReference) left).getQueryString(); - return expr.operator().applyOperator(columnName, (MongoDBConstant) right); - } - - @Override - public void visit(MongoDBConstant c) { - throw new UnsupportedOperationException(); - } - - @Override - public void visit(MongoDBSelect select) { - hasFilter = select.hasFilter(); - setLookup(select); - if (hasFilter) { - setFilter(select); - } - setProjection(select); - hasCountClause = select.getWithCountClause(); - if (hasCountClause) { - setCount(); - } - } - - private void setCount() { - count = Aggregates.count("count"); - } - - private void setFilter(MongoDBSelect select) { - filter = match(this.visitBson(select.getFilterClause())); - } - - private void setLookup(MongoDBSelect select) { - lookup = new ArrayList<>(); - for (MongoDBColumnTestReference reference : select.getLookupList()) { - if (reference.inMainTable()) { - continue; - } - lookup.add(Aggregates.lookup(reference.getTableName(), select.getJoinColumn().getPlainName(), - reference.getPlainName(), reference.getQueryString())); - } - } - - private void setProjection(MongoDBSelect select) { - List stringProjects = new ArrayList<>(); - for (MongoDBColumnTestReference ref : select.getProjectionList()) { - stringProjects.add(ref.getQueryString()); - } - List projections = new ArrayList<>(); - projections.add(include(stringProjects)); - if (select.hasComputed()) { - String name = "computed"; - int number = 0; - for (Node expressionNode : select.getComputedClause()) { - projections.add(Projections.computed(name + number, visitComputed(expressionNode))); - number++; - } - } - projection = project(fields(projections)); - } - - public List getPipeline() { - List result = new ArrayList<>(lookup); - if (hasFilter) { - result.add(filter); - } - result.add(projection); - if (hasCountClause) { - result.add(count); - } - return result; - } -} diff --git a/src/sqlancer/mongodb/visitor/MongoDBVisitor.java b/src/sqlancer/mongodb/visitor/MongoDBVisitor.java deleted file mode 100644 index e02a50f02..000000000 --- a/src/sqlancer/mongodb/visitor/MongoDBVisitor.java +++ /dev/null @@ -1,45 +0,0 @@ -package sqlancer.mongodb.visitor; - -import java.util.List; - -import org.bson.conversions.Bson; - -import sqlancer.common.ast.newast.Node; -import sqlancer.mongodb.ast.MongoDBConstant; -import sqlancer.mongodb.ast.MongoDBExpression; -import sqlancer.mongodb.ast.MongoDBSelect; - -public abstract class MongoDBVisitor { - - public abstract void visit(MongoDBConstant c); - - public abstract void visit(MongoDBSelect s); - - public void visit(Node expr) { - if (expr instanceof MongoDBConstant) { - visit((MongoDBConstant) expr); - } else if (expr instanceof MongoDBSelect) { - visit((MongoDBSelect) expr); - } else { - throw new AssertionError(expr.getClass()); - } - } - - public static List asQuery(Node expr) { - MongoDBToQueryVisitor visitor = new MongoDBToQueryVisitor(); - visitor.visit(expr); - return visitor.getPipeline(); - } - - public static String asStringLog(Node expr) { - MongoDBToLogVisitor visitor = new MongoDBToLogVisitor(); - visitor.visit(expr); - return visitor.getStringLog(); - } - - public static Node cleanNegations(Node expr) { - MongoDBNegateVisitor visitor = new MongoDBNegateVisitor(false); - visitor.visit(expr); - return visitor.getNegatedExpression(); - } -} diff --git a/src/sqlancer/mysql/MySQLBugs.java b/src/sqlancer/mysql/MySQLBugs.java index 9cd87614d..8cb8a3391 100644 --- a/src/sqlancer/mysql/MySQLBugs.java +++ b/src/sqlancer/mysql/MySQLBugs.java @@ -19,6 +19,24 @@ public final class MySQLBugs { // https://bugs.mysql.com/bug.php?id=99135 public static boolean bug99135 = true; + // https://bugs.mysql.com/bug.php?id=111471 + public static boolean bug111471 = true; + + // https://bugs.mysql.com/bug.php?id=112242 + public static boolean bug112242 = true; + + // https://bugs.mysql.com/bug.php?id=112243 + public static boolean bug112243 = true; + + // https://bugs.mysql.com/bug.php?id=112264 + public static boolean bug112264 = true; + + // https://bugs.mysql.com/bug.php?id=114533 + public static boolean bug114533 = true; + + // https://bugs.mysql.com/bug.php?id=114534 + public static boolean bug114534 = true; + private MySQLBugs() { } diff --git a/src/sqlancer/mysql/MySQLErrors.java b/src/sqlancer/mysql/MySQLErrors.java index 39aa33e1e..989c8fed6 100644 --- a/src/sqlancer/mysql/MySQLErrors.java +++ b/src/sqlancer/mysql/MySQLErrors.java @@ -1,5 +1,9 @@ package sqlancer.mysql; +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Pattern; + import sqlancer.common.query.ExpectedErrors; public final class MySQLErrors { @@ -7,9 +11,56 @@ public final class MySQLErrors { private MySQLErrors() { } - public static void addExpressionErrors(ExpectedErrors errors) { + public static List getExpressionErrors() { + ArrayList errors = new ArrayList<>(); + errors.add("BIGINT value is out of range"); // e.g., CAST(-('-1e500') AS SIGNED) errors.add("is not valid for CHARACTER SET"); + + if (MySQLBugs.bug111471) { + errors.add("Memory capacity exceeded"); + } + + return errors; + } + + public static List getExpressionRegexErrors() { + ArrayList errors = new ArrayList<>(); + + if (MySQLBugs.bug114533) { + errors.add(Pattern.compile("For input string: \"0+-0\"")); // match: For input string: + // "00000000000000000000-0" + } + + errors.add(Pattern.compile("Unknown column '.*' in 'order clause'")); + errors.add(Pattern.compile("Unknown column '.*' in 'EXISTS subquery'")); + + return errors; + } + + public static void addExpressionErrors(ExpectedErrors errors) { + errors.addAll(getExpressionErrors()); + errors.addAllRegexes(getExpressionRegexErrors()); + } + + public static List getInsertUpdateErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("doesn't have a default value"); + errors.add("Data truncation"); + errors.add("Incorrect integer value"); + errors.add("Duplicate entry"); + errors.add("Data truncated for column"); + errors.add("Data truncated for functional index"); + errors.add("cannot be null"); + errors.add("Incorrect decimal value"); + errors.add("The value specified for generated column"); + + return errors; + } + + public static void addInsertUpdateErrors(ExpectedErrors errors) { + errors.addAll(getInsertUpdateErrors()); } } diff --git a/src/sqlancer/mysql/MySQLExpectedValueVisitor.java b/src/sqlancer/mysql/MySQLExpectedValueVisitor.java index fd77a9f85..8ad2d8b2b 100644 --- a/src/sqlancer/mysql/MySQLExpectedValueVisitor.java +++ b/src/sqlancer/mysql/MySQLExpectedValueVisitor.java @@ -1,10 +1,14 @@ package sqlancer.mysql; +import java.util.List; + import sqlancer.IgnoreMeException; +import sqlancer.mysql.ast.MySQLAggregate; import sqlancer.mysql.ast.MySQLBetweenOperation; import sqlancer.mysql.ast.MySQLBinaryComparisonOperation; import sqlancer.mysql.ast.MySQLBinaryLogicalOperation; import sqlancer.mysql.ast.MySQLBinaryOperation; +import sqlancer.mysql.ast.MySQLCaseOperator; import sqlancer.mysql.ast.MySQLCastOperation; import sqlancer.mysql.ast.MySQLCollate; import sqlancer.mysql.ast.MySQLColumnReference; @@ -13,10 +17,12 @@ import sqlancer.mysql.ast.MySQLExists; import sqlancer.mysql.ast.MySQLExpression; import sqlancer.mysql.ast.MySQLInOperation; +import sqlancer.mysql.ast.MySQLJoin; import sqlancer.mysql.ast.MySQLOrderByTerm; import sqlancer.mysql.ast.MySQLSelect; import sqlancer.mysql.ast.MySQLStringExpression; import sqlancer.mysql.ast.MySQLTableReference; +import sqlancer.mysql.ast.MySQLText; import sqlancer.mysql.ast.MySQLUnaryPostfixOperation; public class MySQLExpectedValueVisitor implements MySQLVisitor { @@ -153,4 +159,47 @@ public void visit(MySQLCollate collate) { visit(collate.getExpectedValue()); } + @Override + public void visit(MySQLJoin join) { + print(join); + visit(join.getOnClause()); + } + + @Override + public void visit(MySQLText text) { + print(text); + } + + @Override + public void visit(MySQLAggregate aggr) { + // PQS is currently unsupported for aggregates. + throw new IgnoreMeException(); + } + + @Override + public void visit(MySQLCaseOperator caseOp) { + print(caseOp); + + MySQLExpression switchCondition = caseOp.getSwitchCondition(); + if (switchCondition != null) { + print(switchCondition); + visit(switchCondition); + } + + List whenConditions = caseOp.getConditions(); + List thenExpressions = caseOp.getExpressions(); + + for (int i = 0; i < whenConditions.size(); i++) { + print(whenConditions.get(i)); + visit(whenConditions.get(i)); + print(thenExpressions.get(i)); + visit(thenExpressions.get(i)); + } + + MySQLExpression elseExpr = caseOp.getElseExpr(); + if (elseExpr != null) { + print(elseExpr); + visit(elseExpr); + } + } } diff --git a/src/sqlancer/mysql/MySQLGlobalState.java b/src/sqlancer/mysql/MySQLGlobalState.java index 86cf6dd18..10132b57c 100644 --- a/src/sqlancer/mysql/MySQLGlobalState.java +++ b/src/sqlancer/mysql/MySQLGlobalState.java @@ -4,7 +4,6 @@ import java.sql.SQLException; import sqlancer.SQLGlobalState; -import sqlancer.mysql.MySQLOptions.MySQLOracleFactory; public class MySQLGlobalState extends SQLGlobalState { diff --git a/src/sqlancer/mysql/MySQLOptions.java b/src/sqlancer/mysql/MySQLOptions.java index 2a3f16968..9219073d5 100644 --- a/src/sqlancer/mysql/MySQLOptions.java +++ b/src/sqlancer/mysql/MySQLOptions.java @@ -1,6 +1,5 @@ package sqlancer.mysql; -import java.sql.SQLException; import java.util.Arrays; import java.util.List; @@ -8,11 +7,6 @@ import com.beust.jcommander.Parameters; import sqlancer.DBMSSpecificOptions; -import sqlancer.OracleFactory; -import sqlancer.common.oracle.TestOracle; -import sqlancer.mysql.MySQLOptions.MySQLOracleFactory; -import sqlancer.mysql.oracle.MySQLPivotedQuerySynthesisOracle; -import sqlancer.mysql.oracle.MySQLTLPWhereOracle; @Parameters(separators = "=", commandDescription = "MySQL (default port: " + MySQLOptions.DEFAULT_PORT + ", default host: " + MySQLOptions.DEFAULT_HOST + ")") @@ -23,31 +17,6 @@ public class MySQLOptions implements DBMSSpecificOptions { @Parameter(names = "--oracle") public List oracles = Arrays.asList(MySQLOracleFactory.TLP_WHERE); - public enum MySQLOracleFactory implements OracleFactory { - - TLP_WHERE { - - @Override - public TestOracle create(MySQLGlobalState globalState) throws SQLException { - return new MySQLTLPWhereOracle(globalState); - } - - }, - PQS { - - @Override - public TestOracle create(MySQLGlobalState globalState) throws SQLException { - return new MySQLPivotedQuerySynthesisOracle(globalState); - } - - @Override - public boolean requiresAllTablesToContainRows() { - return true; - } - - } - } - @Override public List getTestOracleFactory() { return oracles; diff --git a/src/sqlancer/mysql/MySQLOracleFactory.java b/src/sqlancer/mysql/MySQLOracleFactory.java new file mode 100644 index 000000000..83e08677a --- /dev/null +++ b/src/sqlancer/mysql/MySQLOracleFactory.java @@ -0,0 +1,86 @@ +package sqlancer.mysql; + +import java.sql.SQLException; +import java.util.Optional; + +import sqlancer.OracleFactory; +import sqlancer.common.oracle.CERTOracle; +import sqlancer.common.oracle.TLPWhereOracle; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLancerResultSet; +import sqlancer.mysql.gen.MySQLExpressionGenerator; +import sqlancer.mysql.oracle.MySQLDQEOracle; +import sqlancer.mysql.oracle.MySQLDQPOracle; +import sqlancer.mysql.oracle.MySQLFuzzer; +import sqlancer.mysql.oracle.MySQLPivotedQuerySynthesisOracle; + +public enum MySQLOracleFactory implements OracleFactory { + + TLP_WHERE { + @Override + public TestOracle create(MySQLGlobalState globalState) throws SQLException { + MySQLExpressionGenerator gen = new MySQLExpressionGenerator(globalState); + ExpectedErrors expectedErrors = ExpectedErrors.newErrors().with(MySQLErrors.getExpressionErrors()) + .withRegex(MySQLErrors.getExpressionRegexErrors()).build(); + + return new TLPWhereOracle<>(globalState, gen, expectedErrors); + } + + }, + PQS { + @Override + public TestOracle create(MySQLGlobalState globalState) throws SQLException { + return new MySQLPivotedQuerySynthesisOracle(globalState); + } + + @Override + public boolean requiresAllTablesToContainRows() { + return true; + } + + }, + CERT { + @Override + public TestOracle create(MySQLGlobalState globalState) throws SQLException { + MySQLExpressionGenerator gen = new MySQLExpressionGenerator(globalState); + ExpectedErrors expectedErrors = ExpectedErrors.newErrors().with(MySQLErrors.getExpressionErrors()) + .withRegex(MySQLErrors.getExpressionRegexErrors()).build(); + CERTOracle.CheckedFunction> rowCountParser = (rs) -> { + int rowCount = rs.getInt(10); + return Optional.of((long) rowCount); + }; + CERTOracle.CheckedFunction> queryPlanParser = (rs) -> { + String operation = rs.getString(2); + return Optional.of(operation); + }; + + return new CERTOracle<>(globalState, gen, expectedErrors, rowCountParser, queryPlanParser); + + } + + @Override + public boolean requiresAllTablesToContainRows() { + return true; + } + }, + FUZZER { + @Override + public TestOracle create(MySQLGlobalState globalState) throws Exception { + return new MySQLFuzzer(globalState); + } + + }, + DQP { + @Override + public TestOracle create(MySQLGlobalState globalState) throws SQLException { + return new MySQLDQPOracle(globalState); + } + }, + DQE { + @Override + public TestOracle create(MySQLGlobalState globalState) throws SQLException { + return new MySQLDQEOracle(globalState); + } + }; +} diff --git a/src/sqlancer/mysql/MySQLProvider.java b/src/sqlancer/mysql/MySQLProvider.java index 0f1704bf3..80a23b947 100644 --- a/src/sqlancer/mysql/MySQLProvider.java +++ b/src/sqlancer/mysql/MySQLProvider.java @@ -4,6 +4,8 @@ import java.sql.DriverManager; import java.sql.SQLException; import java.sql.Statement; +import java.util.List; +import java.util.stream.Collectors; import com.google.auto.service.AutoService; @@ -16,8 +18,11 @@ import sqlancer.SQLProviderAdapter; import sqlancer.StatementExecutor; import sqlancer.common.DBMSCommon; +import sqlancer.common.query.ExpectedErrors; import sqlancer.common.query.SQLQueryAdapter; import sqlancer.common.query.SQLQueryProvider; +import sqlancer.mysql.MySQLSchema.MySQLColumn; +import sqlancer.mysql.MySQLSchema.MySQLTable; import sqlancer.mysql.gen.MySQLAlterTable; import sqlancer.mysql.gen.MySQLDeleteGenerator; import sqlancer.mysql.gen.MySQLDropIndex; @@ -25,6 +30,7 @@ import sqlancer.mysql.gen.MySQLSetGenerator; import sqlancer.mysql.gen.MySQLTableGenerator; import sqlancer.mysql.gen.MySQLTruncateTableGenerator; +import sqlancer.mysql.gen.MySQLUpdateGenerator; import sqlancer.mysql.gen.admin.MySQLFlush; import sqlancer.mysql.gen.admin.MySQLReset; import sqlancer.mysql.gen.datadef.MySQLIndexGenerator; @@ -56,11 +62,7 @@ enum Action implements AbstractAction { SELECT_INFO((g) -> new SQLQueryAdapter( "select TABLE_NAME, ENGINE from information_schema.TABLES where table_schema = '" + g.getDatabaseName() + "'")), // - CREATE_TABLE((g) -> { - // TODO refactor - String tableName = DBMSCommon.createTableName(g.getSchema().getDatabaseTables().size()); - return MySQLTableGenerator.generate(g, tableName); - }), // + UPDATE(MySQLUpdateGenerator::create), // DELETE(MySQLDeleteGenerator::delete), // DROP_INDEX(MySQLDropIndex::generate); @@ -86,9 +88,6 @@ private static int mapActions(MySQLGlobalState globalState, Action a) { case SHOW_TABLES: nrPerformed = r.getInteger(0, 1); break; - case CREATE_TABLE: - nrPerformed = r.getInteger(0, 1); - break; case INSERT: nrPerformed = r.getInteger(0, globalState.getOptions().getMaxNumberInserts()); break; @@ -126,6 +125,9 @@ private static int mapActions(MySQLGlobalState globalState, Action a) { case SELECT_INFO: nrPerformed = r.getInteger(0, 10); break; + case UPDATE: + nrPerformed = r.getInteger(0, 10); + break; case DELETE: nrPerformed = r.getInteger(0, 10); break; @@ -137,7 +139,7 @@ private static int mapActions(MySQLGlobalState globalState, Action a) { @Override public void generateDatabase(MySQLGlobalState globalState) throws Exception { - while (globalState.getSchema().getDatabaseTables().size() < Randomly.smallNumber() + 1) { + while (globalState.getSchema().getDatabaseTables().size() < Randomly.getNotCachedInteger(1, 2)) { String tableName = DBMSCommon.createTableName(globalState.getSchema().getDatabaseTables().size()); SQLQueryAdapter createTable = MySQLTableGenerator.generate(globalState, tableName); globalState.executeStatement(createTable); @@ -150,6 +152,23 @@ public void generateDatabase(MySQLGlobalState globalState) throws Exception { } }); se.executeStatements(); + + if (globalState.getDbmsSpecificOptions().getTestOracleFactory().stream() + .anyMatch((o) -> o == MySQLOracleFactory.CERT)) { + // Enfore statistic collected for all tables + ExpectedErrors errors = new ExpectedErrors(); + MySQLErrors.addExpressionErrors(errors); + for (MySQLTable table : globalState.getSchema().getDatabaseTables()) { + StringBuilder sb = new StringBuilder(); + sb.append("ANALYZE TABLE "); + sb.append(table.getName()); + sb.append(" UPDATE HISTOGRAM ON "); + String columns = table.getColumns().stream().map(MySQLColumn::getName) + .collect(Collectors.joining(", ")); + sb.append(columns + ";"); + globalState.executeStatement(new SQLQueryAdapter(sb.toString(), errors)); + } + } } @Override @@ -188,4 +207,15 @@ public String getDBMSName() { return "mysql"; } + @Override + public boolean addRowsToAllTables(MySQLGlobalState globalState) throws Exception { + List tablesNoRow = globalState.getSchema().getDatabaseTables().stream() + .filter(t -> t.getNrRows(globalState) == 0).collect(Collectors.toList()); + for (MySQLTable table : tablesNoRow) { + SQLQueryAdapter queryAddRows = MySQLInsertGenerator.insertRow(globalState, table); + globalState.executeStatement(queryAddRows); + } + return true; + } + } diff --git a/src/sqlancer/mysql/MySQLSchema.java b/src/sqlancer/mysql/MySQLSchema.java index 0384f34df..c8a30614f 100644 --- a/src/sqlancer/mysql/MySQLSchema.java +++ b/src/sqlancer/mysql/MySQLSchema.java @@ -76,6 +76,7 @@ public int getPrecision() { return precision; } + @Override public boolean isPrimaryKey() { return isPrimaryKey; } @@ -194,10 +195,6 @@ public MySQLEngine getEngine() { return engine; } - public boolean hasPrimaryKey() { - return getColumns().stream().anyMatch(c -> c.isPrimaryKey()); - } - } public static final class MySQLIndex extends TableIndex { diff --git a/src/sqlancer/mysql/MySQLToStringVisitor.java b/src/sqlancer/mysql/MySQLToStringVisitor.java index 06f13b008..be82d45b5 100644 --- a/src/sqlancer/mysql/MySQLToStringVisitor.java +++ b/src/sqlancer/mysql/MySQLToStringVisitor.java @@ -5,10 +5,13 @@ import sqlancer.Randomly; import sqlancer.common.visitor.ToStringVisitor; +import sqlancer.mysql.ast.MySQLAggregate; +import sqlancer.mysql.ast.MySQLAggregate.MySQLAggregateFunction; import sqlancer.mysql.ast.MySQLBetweenOperation; import sqlancer.mysql.ast.MySQLBinaryComparisonOperation; import sqlancer.mysql.ast.MySQLBinaryLogicalOperation; import sqlancer.mysql.ast.MySQLBinaryOperation; +import sqlancer.mysql.ast.MySQLCaseOperator; import sqlancer.mysql.ast.MySQLCastOperation; import sqlancer.mysql.ast.MySQLCollate; import sqlancer.mysql.ast.MySQLColumnReference; @@ -17,11 +20,13 @@ import sqlancer.mysql.ast.MySQLExists; import sqlancer.mysql.ast.MySQLExpression; import sqlancer.mysql.ast.MySQLInOperation; +import sqlancer.mysql.ast.MySQLJoin; import sqlancer.mysql.ast.MySQLOrderByTerm; import sqlancer.mysql.ast.MySQLOrderByTerm.MySQLOrder; import sqlancer.mysql.ast.MySQLSelect; import sqlancer.mysql.ast.MySQLStringExpression; import sqlancer.mysql.ast.MySQLTableReference; +import sqlancer.mysql.ast.MySQLText; import sqlancer.mysql.ast.MySQLUnaryPostfixOperation; public class MySQLToStringVisitor extends ToStringVisitor implements MySQLVisitor { @@ -36,6 +41,11 @@ public void visitSpecific(MySQLExpression expr) { @Override public void visit(MySQLSelect s) { sb.append("SELECT "); + if (s.getHint() != null) { + sb.append("/*+ "); + visit(s.getHint()); + sb.append("*/ "); + } switch (s.getFromOptions()) { case DISTINCT: sb.append("DISTINCT "); @@ -50,7 +60,7 @@ public void visit(MySQLSelect s) { throw new AssertionError(); } sb.append(s.getModifiers().stream().collect(Collectors.joining(" "))); - if (s.getModifiers().size() > 0) { + if (!s.getModifiers().isEmpty()) { sb.append(" "); } if (s.getFetchColumns() == null) { @@ -83,7 +93,7 @@ public void visit(MySQLSelect s) { sb.append(" WHERE "); visit(whereClause); } - if (s.getGroupByExpressions() != null && s.getGroupByExpressions().size() > 0) { + if (s.getGroupByExpressions() != null && !s.getGroupByExpressions().isEmpty()) { sb.append(" "); sb.append("GROUP BY "); List groupBys = s.getGroupByExpressions(); @@ -94,14 +104,14 @@ public void visit(MySQLSelect s) { visit(groupBys.get(i)); } } - if (!s.getOrderByExpressions().isEmpty()) { + if (!s.getOrderByClauses().isEmpty()) { sb.append(" ORDER BY "); - List orderBys = s.getOrderByExpressions(); + List orderBys = s.getOrderByClauses(); for (int i = 0; i < orderBys.size(); i++) { if (i != 0) { sb.append(", "); } - visit(s.getOrderByExpressions().get(i)); + visit(s.getOrderByClauses().get(i)); } } if (s.getLimitClause() != null) { @@ -278,4 +288,88 @@ public void visit(MySQLCollate collate) { sb.append(")"); } + @Override + public void visit(MySQLJoin join) { + sb.append(" "); + switch (join.getType()) { + case NATURAL: + sb.append("NATURAL "); + break; + case INNER: + sb.append("INNER "); + break; + case STRAIGHT: + sb.append("STRAIGHT_"); + break; + case LEFT: + sb.append("LEFT "); + break; + case RIGHT: + sb.append("RIGHT "); + break; + case CROSS: + sb.append("CROSS "); + break; + default: + throw new AssertionError(join.getType()); + } + sb.append("JOIN "); + sb.append(join.getTable().getName()); + if (join.getOnClause() != null) { + sb.append(" ON "); + visit(join.getOnClause()); + } + } + + @Override + public void visit(MySQLText text) { + sb.append(text.getText()); + } + + @Override + public void visit(MySQLAggregate aggr) { + MySQLAggregateFunction func = aggr.getFunc(); + String option = func.getOption(); + List exprs = aggr.getExprs(); + + sb.append(func.getName()); + sb.append("("); + if (option != null) { + sb.append(option); + sb.append(" "); + } + for (int i = 0; i < exprs.size(); i++) { + if (i != 0) { + sb.append(", "); + } + visit(exprs.get(i)); + } + sb.append(")"); + } + + @Override + public void visit(MySQLCaseOperator caseOp) { + sb.append("(CASE "); + + if (caseOp.getSwitchCondition() != null) { + visit(caseOp.getSwitchCondition()); + sb.append(" "); + } + + for (int i = 0; i < caseOp.getConditions().size(); i++) { + if (i > 0) { + sb.append(" "); + } + sb.append("WHEN "); + visit(caseOp.getConditions().get(i)); + sb.append(" THEN "); + visit(caseOp.getExpressions().get(i)); + } + + if (caseOp.getElseExpr() != null) { + sb.append(" ELSE "); + visit(caseOp.getElseExpr()); + } + sb.append(" END)"); + } } diff --git a/src/sqlancer/mysql/MySQLVisitor.java b/src/sqlancer/mysql/MySQLVisitor.java index eeda4f681..12c93ecfc 100644 --- a/src/sqlancer/mysql/MySQLVisitor.java +++ b/src/sqlancer/mysql/MySQLVisitor.java @@ -1,9 +1,11 @@ package sqlancer.mysql; +import sqlancer.mysql.ast.MySQLAggregate; import sqlancer.mysql.ast.MySQLBetweenOperation; import sqlancer.mysql.ast.MySQLBinaryComparisonOperation; import sqlancer.mysql.ast.MySQLBinaryLogicalOperation; import sqlancer.mysql.ast.MySQLBinaryOperation; +import sqlancer.mysql.ast.MySQLCaseOperator; import sqlancer.mysql.ast.MySQLCastOperation; import sqlancer.mysql.ast.MySQLCollate; import sqlancer.mysql.ast.MySQLColumnReference; @@ -12,10 +14,12 @@ import sqlancer.mysql.ast.MySQLExists; import sqlancer.mysql.ast.MySQLExpression; import sqlancer.mysql.ast.MySQLInOperation; +import sqlancer.mysql.ast.MySQLJoin; import sqlancer.mysql.ast.MySQLOrderByTerm; import sqlancer.mysql.ast.MySQLSelect; import sqlancer.mysql.ast.MySQLStringExpression; import sqlancer.mysql.ast.MySQLTableReference; +import sqlancer.mysql.ast.MySQLText; import sqlancer.mysql.ast.MySQLUnaryPostfixOperation; public interface MySQLVisitor { @@ -52,6 +56,14 @@ public interface MySQLVisitor { void visit(MySQLCollate collate); + void visit(MySQLJoin join); + + void visit(MySQLText text); + + void visit(MySQLAggregate aggregate); + + void visit(MySQLCaseOperator caseOp); + default void visit(MySQLExpression expr) { if (expr instanceof MySQLConstant) { visit((MySQLConstant) expr); @@ -77,6 +89,8 @@ default void visit(MySQLExpression expr) { visit((MySQLOrderByTerm) expr); } else if (expr instanceof MySQLExists) { visit((MySQLExists) expr); + } else if (expr instanceof MySQLJoin) { + visit((MySQLJoin) expr); } else if (expr instanceof MySQLStringExpression) { visit((MySQLStringExpression) expr); } else if (expr instanceof MySQLBetweenOperation) { @@ -85,6 +99,12 @@ default void visit(MySQLExpression expr) { visit((MySQLTableReference) expr); } else if (expr instanceof MySQLCollate) { visit((MySQLCollate) expr); + } else if (expr instanceof MySQLText) { + visit((MySQLText) expr); + } else if (expr instanceof MySQLAggregate) { + visit((MySQLAggregate) expr); + } else if (expr instanceof MySQLCaseOperator) { + visit((MySQLCaseOperator) expr); } else { throw new AssertionError(expr); } diff --git a/src/sqlancer/mysql/ast/MySQLAggregate.java b/src/sqlancer/mysql/ast/MySQLAggregate.java new file mode 100644 index 000000000..94c4e426b --- /dev/null +++ b/src/sqlancer/mysql/ast/MySQLAggregate.java @@ -0,0 +1,55 @@ +package sqlancer.mysql.ast; + +import java.util.List; + +public class MySQLAggregate implements MySQLExpression { + + public enum MySQLAggregateFunction { + // See https://dev.mysql.com/doc/refman/8.4/en/aggregate-functions.html#function_count. + COUNT("COUNT", null, false), COUNT_DISTINCT("COUNT", "DISTINCT", true), + // See https://dev.mysql.com/doc/refman/8.4/en/aggregate-functions.html#function_sum. + SUM("SUM", null, false), SUM_DISTINCT("SUM", "DISTINCT", false), + // See https://dev.mysql.com/doc/refman/8.4/en/aggregate-functions.html#function_min. + MIN("MIN", null, false), MIN_DISTINCT("MIN", "DISTINCT", false), + // See https://dev.mysql.com/doc/refman/8.4/en/aggregate-functions.html#function_max. + MAX("MAX", null, false), MAX_DISTINCT("MAX", "DISTINCT", false); + + private final String name; + private final String option; + private final boolean isVariadic; + + MySQLAggregateFunction(String name, String option, boolean isVariadic) { + this.name = name; + this.option = option; + this.isVariadic = isVariadic; + } + + public String getName() { + return this.name; + } + + public String getOption() { + return option; + } + + public boolean isVariadic() { + return this.isVariadic; + } + } + + private final List exprs; + private final MySQLAggregateFunction func; + + public MySQLAggregate(List exprs, MySQLAggregateFunction func) { + this.exprs = exprs; + this.func = func; + } + + public List getExprs() { + return exprs; + } + + public MySQLAggregateFunction getFunc() { + return func; + } +} diff --git a/src/sqlancer/mysql/ast/MySQLBinaryOperation.java b/src/sqlancer/mysql/ast/MySQLBinaryOperation.java index 33293f26b..af0b0ccf7 100644 --- a/src/sqlancer/mysql/ast/MySQLBinaryOperation.java +++ b/src/sqlancer/mysql/ast/MySQLBinaryOperation.java @@ -78,20 +78,20 @@ public MySQLConstant getExpectedValue() { /* workaround for https://bugs.mysql.com/bug.php?id=95960 */ if (leftExpected.isString()) { String text = leftExpected.castAsString(); - while ((text.startsWith(" ") || text.startsWith("\t")) && text.length() > 0) { + while (text.startsWith(" ") || text.startsWith("\t")) { text = text.substring(1); } - if (text.length() > 0 && (text.startsWith("\n") || text.startsWith("."))) { + if (text.startsWith("\n") || text.startsWith(".")) { throw new IgnoreMeException(); } } if (rightExpected.isString()) { String text = rightExpected.castAsString(); - while ((text.startsWith(" ") || text.startsWith("\t")) && text.length() > 0) { + while (text.startsWith(" ") || text.startsWith("\t")) { text = text.substring(1); } - if (text.length() > 0 && (text.startsWith("\n") || text.startsWith("."))) { + if (text.startsWith("\n") || text.startsWith(".")) { throw new IgnoreMeException(); } } diff --git a/src/sqlancer/mysql/ast/MySQLCaseOperator.java b/src/sqlancer/mysql/ast/MySQLCaseOperator.java new file mode 100644 index 000000000..ceb7fbbaa --- /dev/null +++ b/src/sqlancer/mysql/ast/MySQLCaseOperator.java @@ -0,0 +1,48 @@ +package sqlancer.mysql.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewCaseOperatorNode; + +public class MySQLCaseOperator extends NewCaseOperatorNode implements MySQLExpression { + + public MySQLCaseOperator(MySQLExpression switchCondition, List whenExprs, + List thenExprs, MySQLExpression elseExpr) { + super(switchCondition, whenExprs, thenExprs, elseExpr); + } + + @Override + public MySQLConstant getExpectedValue() { + int nrConditions = getConditions().size(); + + MySQLExpression switchCondition = getSwitchCondition(); + List whenExprs = getConditions(); + List thenExprs = getExpressions(); + MySQLExpression elseExpr = getElseExpr(); + + if (switchCondition != null) { + MySQLConstant switchValue = switchCondition.getExpectedValue(); + + for (int i = 0; i < nrConditions; i++) { + MySQLConstant whenValue = whenExprs.get(i).getExpectedValue(); + MySQLConstant isConditionMatched = switchValue.isEquals(whenValue); + if (!isConditionMatched.isNull() && isConditionMatched.asBooleanNotNull()) { + return thenExprs.get(i).getExpectedValue(); + } + } + } else { + for (int i = 0; i < nrConditions; i++) { + MySQLConstant whenValue = whenExprs.get(i).getExpectedValue(); + if (!whenValue.isNull() && whenValue.asBooleanNotNull()) { + return thenExprs.get(i).getExpectedValue(); + } + } + } + + if (elseExpr != null) { + return elseExpr.getExpectedValue(); + } + + return MySQLConstant.createNullConstant(); + } +} diff --git a/src/sqlancer/mysql/ast/MySQLExpression.java b/src/sqlancer/mysql/ast/MySQLExpression.java index 61a3b8aeb..1f3ae5bcb 100644 --- a/src/sqlancer/mysql/ast/MySQLExpression.java +++ b/src/sqlancer/mysql/ast/MySQLExpression.java @@ -1,6 +1,9 @@ package sqlancer.mysql.ast; -public interface MySQLExpression { +import sqlancer.common.ast.newast.Expression; +import sqlancer.mysql.MySQLSchema.MySQLColumn; + +public interface MySQLExpression extends Expression { default MySQLConstant getExpectedValue() { throw new AssertionError("PQS not supported for this operator"); diff --git a/src/sqlancer/mysql/ast/MySQLJoin.java b/src/sqlancer/mysql/ast/MySQLJoin.java index 8558e43a9..c063b4f1b 100644 --- a/src/sqlancer/mysql/ast/MySQLJoin.java +++ b/src/sqlancer/mysql/ast/MySQLJoin.java @@ -1,10 +1,87 @@ package sqlancer.mysql.ast; -public class MySQLJoin implements MySQLExpression { +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.ast.newast.Join; +import sqlancer.mysql.MySQLGlobalState; +import sqlancer.mysql.MySQLSchema.MySQLColumn; +import sqlancer.mysql.MySQLSchema.MySQLTable; +import sqlancer.mysql.gen.MySQLExpressionGenerator; + +public class MySQLJoin implements MySQLExpression, Join { + + public enum JoinType { + NATURAL, INNER, STRAIGHT, LEFT, RIGHT, CROSS; + } + + private final MySQLTable table; + private MySQLExpression onClause; + private JoinType type; + + public MySQLJoin(MySQLJoin other) { + this.table = other.table; + this.onClause = other.onClause; + this.type = other.type; + } + + public MySQLJoin(MySQLTable table, MySQLExpression onClause, JoinType type) { + this.table = table; + this.onClause = onClause; + this.type = type; + } + + public MySQLTable getTable() { + return table; + } + + public MySQLExpression getOnClause() { + return onClause; + } + + public JoinType getType() { + return type; + } @Override - public MySQLConstant getExpectedValue() { - throw new UnsupportedOperationException(); + public void setOnClause(MySQLExpression onClause) { + this.onClause = onClause; } + public void setType(JoinType type) { + this.type = type; + } + + public static List getRandomJoinClauses(List tables, MySQLGlobalState globalState) { + List joinStatements = new ArrayList<>(); + List options = new ArrayList<>(Arrays.asList(JoinType.values())); + List columns = new ArrayList<>(); + if (tables.size() > 1) { + int nrJoinClauses = (int) Randomly.getNotCachedInteger(0, tables.size()); + // Natural join is incompatible with other joins + // because it needs unique column names + // while other joins will produce duplicate column names + if (nrJoinClauses > 1) { + options.remove(JoinType.NATURAL); + } + for (int i = 0; i < nrJoinClauses; i++) { + MySQLTable table = Randomly.fromList(tables); + tables.remove(table); + columns.addAll(table.getColumns()); + MySQLExpressionGenerator joinGen = new MySQLExpressionGenerator(globalState).setColumns(columns); + MySQLExpression joinClause = joinGen.generateExpression(); + JoinType selectedOption = Randomly.fromList(options); + if (selectedOption == JoinType.NATURAL) { + // NATURAL joins do not have an ON clause + joinClause = null; + } + MySQLJoin j = new MySQLJoin(table, joinClause, selectedOption); + joinStatements.add(j); + } + + } + return joinStatements; + } } diff --git a/src/sqlancer/mysql/ast/MySQLSelect.java b/src/sqlancer/mysql/ast/MySQLSelect.java index 7abd1f639..7b9243c20 100644 --- a/src/sqlancer/mysql/ast/MySQLSelect.java +++ b/src/sqlancer/mysql/ast/MySQLSelect.java @@ -2,13 +2,20 @@ import java.util.Collections; import java.util.List; +import java.util.stream.Collectors; import sqlancer.common.ast.SelectBase; +import sqlancer.common.ast.newast.Select; +import sqlancer.mysql.MySQLSchema.MySQLColumn; +import sqlancer.mysql.MySQLSchema.MySQLTable; +import sqlancer.mysql.MySQLVisitor; -public class MySQLSelect extends SelectBase implements MySQLExpression { +public class MySQLSelect extends SelectBase + implements MySQLExpression, Select { private SelectType fromOptions = SelectType.ALL; private List modifiers = Collections.emptyList(); + private MySQLText hint; public enum SelectType { DISTINCT, ALL, DISTINCTROW; @@ -39,4 +46,28 @@ public MySQLConstant getExpectedValue() { return null; } + public void setHint(MySQLText hint) { + this.hint = hint; + } + + public MySQLText getHint() { + return hint; + } + + @Override + public void setJoinClauses(List joinStatements) { + List expressions = joinStatements.stream().map(e -> (MySQLExpression) e) + .collect(Collectors.toList()); + setJoinList(expressions); + } + + @Override + public List getJoinClauses() { + return getJoinList().stream().map(e -> (MySQLJoin) e).collect(Collectors.toList()); + } + + @Override + public String asString() { + return MySQLVisitor.asString(this); + } } diff --git a/src/sqlancer/mysql/ast/MySQLText.java b/src/sqlancer/mysql/ast/MySQLText.java new file mode 100644 index 000000000..36040a383 --- /dev/null +++ b/src/sqlancer/mysql/ast/MySQLText.java @@ -0,0 +1,14 @@ +package sqlancer.mysql.ast; + +public class MySQLText implements MySQLExpression { + + private final String text; + + public MySQLText(String text) { + this.text = text; + } + + public String getText() { + return text; + } +} diff --git a/src/sqlancer/mysql/gen/MySQLAlterTable.java b/src/sqlancer/mysql/gen/MySQLAlterTable.java index cb5db6bde..f2c952016 100644 --- a/src/sqlancer/mysql/gen/MySQLAlterTable.java +++ b/src/sqlancer/mysql/gen/MySQLAlterTable.java @@ -130,15 +130,15 @@ private SQLQueryAdapter create() { break; case STATS_AUTO_RECALC: sb.append("STATS_AUTO_RECALC "); - sb.append(Randomly.fromOptions(0, 1, "DEFAULT")); + sb.append(Randomly.fromOptions("0", "1", "DEFAULT")); break; case STATS_PERSISTENT: sb.append("STATS_PERSISTENT "); - sb.append(Randomly.fromOptions(0, 1, "DEFAULT")); + sb.append(Randomly.fromOptions("0", "1", "DEFAULT")); break; case PACK_KEYS: sb.append("PACK_KEYS "); - sb.append(Randomly.fromOptions(0, 1, "DEFAULT")); + sb.append(Randomly.fromOptions("0", "1", "DEFAULT")); break; // not relevant: // case WITH_WITHOUT_VALIDATION: diff --git a/src/sqlancer/mysql/gen/MySQLExpressionGenerator.java b/src/sqlancer/mysql/gen/MySQLExpressionGenerator.java index 7911a0f3e..98641ab26 100644 --- a/src/sqlancer/mysql/gen/MySQLExpressionGenerator.java +++ b/src/sqlancer/mysql/gen/MySQLExpressionGenerator.java @@ -2,14 +2,23 @@ import java.util.ArrayList; import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import sqlancer.IgnoreMeException; import sqlancer.Randomly; +import sqlancer.common.gen.CERTGenerator; +import sqlancer.common.gen.TLPWhereGenerator; import sqlancer.common.gen.UntypedExpressionGenerator; +import sqlancer.common.schema.AbstractTables; import sqlancer.mysql.MySQLBugs; import sqlancer.mysql.MySQLGlobalState; import sqlancer.mysql.MySQLSchema.MySQLColumn; import sqlancer.mysql.MySQLSchema.MySQLRowValue; +import sqlancer.mysql.MySQLSchema.MySQLTable; +import sqlancer.mysql.ast.MySQLAggregate; +import sqlancer.mysql.ast.MySQLAggregate.MySQLAggregateFunction; import sqlancer.mysql.ast.MySQLBetweenOperation; import sqlancer.mysql.ast.MySQLBinaryComparisonOperation; import sqlancer.mysql.ast.MySQLBinaryComparisonOperation.BinaryComparisonOperator; @@ -17,6 +26,7 @@ import sqlancer.mysql.ast.MySQLBinaryLogicalOperation.MySQLBinaryLogicalOperator; import sqlancer.mysql.ast.MySQLBinaryOperation; import sqlancer.mysql.ast.MySQLBinaryOperation.MySQLBinaryOperator; +import sqlancer.mysql.ast.MySQLCaseOperator; import sqlancer.mysql.ast.MySQLCastOperation; import sqlancer.mysql.ast.MySQLColumnReference; import sqlancer.mysql.ast.MySQLComputableFunction; @@ -26,17 +36,23 @@ import sqlancer.mysql.ast.MySQLExists; import sqlancer.mysql.ast.MySQLExpression; import sqlancer.mysql.ast.MySQLInOperation; +import sqlancer.mysql.ast.MySQLJoin; import sqlancer.mysql.ast.MySQLOrderByTerm; import sqlancer.mysql.ast.MySQLOrderByTerm.MySQLOrder; +import sqlancer.mysql.ast.MySQLSelect; import sqlancer.mysql.ast.MySQLStringExpression; +import sqlancer.mysql.ast.MySQLTableReference; import sqlancer.mysql.ast.MySQLUnaryPostfixOperation; import sqlancer.mysql.ast.MySQLUnaryPrefixOperation; import sqlancer.mysql.ast.MySQLUnaryPrefixOperation.MySQLUnaryPrefixOperator; -public class MySQLExpressionGenerator extends UntypedExpressionGenerator { +public class MySQLExpressionGenerator extends UntypedExpressionGenerator + implements TLPWhereGenerator, + CERTGenerator { private final MySQLGlobalState state; private MySQLRowValue rowVal; + private List tables; public MySQLExpressionGenerator(MySQLGlobalState state) { this.state = state; @@ -49,7 +65,7 @@ public MySQLExpressionGenerator setRowVal(MySQLRowValue rowVal) { private enum Actions { COLUMN, LITERAL, UNARY_PREFIX_OPERATION, UNARY_POSTFIX, COMPUTABLE_FUNCTION, BINARY_LOGICAL_OPERATOR, - BINARY_COMPARISON_OPERATION, CAST, IN_OPERATION, BINARY_OPERATION, EXISTS, BETWEEN_OPERATOR; + BINARY_COMPARISON_OPERATION, CAST, IN_OPERATION, BINARY_OPERATION, EXISTS, BETWEEN_OPERATOR, CASE_OPERATOR; } @Override @@ -65,10 +81,6 @@ public MySQLExpression generateExpression(int depth) { case UNARY_PREFIX_OPERATION: MySQLExpression subExpr = generateExpression(depth + 1); MySQLUnaryPrefixOperator random = MySQLUnaryPrefixOperator.getRandom(); - if (random == MySQLUnaryPrefixOperator.MINUS) { - // workaround for https://bugs.mysql.com/bug.php?id=99122 - throw new IgnoreMeException(); - } return new MySQLUnaryPrefixOperation(subExpr, random); case UNARY_POSTFIX: return new MySQLUnaryPostfixOperation(generateExpression(depth + 1), @@ -106,6 +118,10 @@ public MySQLExpression generateExpression(int depth) { } return new MySQLBetweenOperation(generateExpression(depth + 1), generateExpression(depth + 1), generateExpression(depth + 1)); + case CASE_OPERATOR: + int nr = Randomly.smallNumber() + 1; + return new MySQLCaseOperator(generateExpression(depth + 1), generateExpressions(nr, depth + 1), + generateExpressions(nr, depth + 1), generateExpression(depth + 1)); default: throw new AssertionError(); } @@ -156,34 +172,9 @@ public MySQLExpression generateConstant() { case STRING: /* Replace characters that still trigger open bugs in MySQL */ String string = state.getRandomly().getString().replace("\\", "").replace("\n", ""); - if (string.startsWith("\n")) { - // workaround for https://bugs.mysql.com/bug.php?id=99130 - throw new IgnoreMeException(); - } - if (string.startsWith("-0") || string.startsWith("0.") || string.startsWith(".")) { - // https://bugs.mysql.com/bug.php?id=99145 - throw new IgnoreMeException(); - } - MySQLConstant createStringConstant = MySQLConstant.createStringConstant(string); - // if (Randomly.getBoolean()) { - // return new MySQLCollate(createStringConstant, - // Randomly.fromOptions("ascii_bin", "binary")); - // } - if (string.startsWith("1e")) { - // https://bugs.mysql.com/bug.php?id=99146 - throw new IgnoreMeException(); - } - return createStringConstant; + return MySQLConstant.createStringConstant(string); case DOUBLE: double val = state.getRandomly().getDouble(); - if (Math.abs(val) <= 1 && val != 0) { - // https://bugs.mysql.com/bug.php?id=99145 - throw new IgnoreMeException(); - } - if (Math.abs(val) > 1.0E30) { - // https://bugs.mysql.com/bug.php?id=99146 - throw new IgnoreMeException(); - } return new MySQLDoubleConstant(val); default: throw new AssertionError(); @@ -227,4 +218,139 @@ public List generateOrderBys() { return newOrderBys; } + @Override + public MySQLExpressionGenerator setTablesAndColumns(AbstractTables tables) { + this.columns = tables.getColumns(); + this.tables = tables.getTables(); + + return this; + } + + @Override + public MySQLExpression generateBooleanExpression() { + return generateExpression(); + } + + @Override + public MySQLSelect generateSelect() { + return new MySQLSelect(); + } + + @Override + public List getRandomJoinClauses() { + return List.of(); + } + + @Override + public List getTableRefs() { + return tables.stream().map(t -> new MySQLTableReference(t)).collect(Collectors.toList()); + } + + @Override + public List generateFetchColumns(boolean shouldCreateDummy) { + return columns.stream().map(c -> new MySQLColumnReference(c, null)).collect(Collectors.toList()); + } + + @Override + public String generateExplainQuery(MySQLSelect select) { + return "EXPLAIN " + select.asString(); + } + + public MySQLAggregate generateAggregate() { + MySQLAggregateFunction func = Randomly.fromOptions(MySQLAggregateFunction.values()); + + if (func.isVariadic()) { + int nrExprs = Randomly.smallNumber() + 1; + List exprs = IntStream.range(0, nrExprs).mapToObj(index -> generateExpression()) + .collect(Collectors.toList()); + + return new MySQLAggregate(exprs, func); + } else { + return new MySQLAggregate(List.of(generateExpression()), func); + } + } + + @Override + public boolean mutate(MySQLSelect select) { + List> mutators = new ArrayList<>(); + + mutators.add(this::mutateWhere); + mutators.add(this::mutateGroupBy); + mutators.add(this::mutateHaving); + mutators.add(this::mutateAnd); + mutators.add(this::mutateOr); + mutators.add(this::mutateDistinct); + + return Randomly.fromList(mutators).apply(select); + } + + boolean mutateDistinct(MySQLSelect select) { + MySQLSelect.SelectType selectType = select.getFromOptions(); + if (selectType != MySQLSelect.SelectType.ALL) { + select.setSelectType(MySQLSelect.SelectType.ALL); + return true; + } else { + select.setSelectType(MySQLSelect.SelectType.DISTINCT); + return false; + } + } + + boolean mutateWhere(MySQLSelect select) { + boolean increase = select.getWhereClause() != null; + if (increase) { + select.setWhereClause(null); + } else { + select.setWhereClause(generateExpression()); + } + return increase; + } + + boolean mutateGroupBy(MySQLSelect select) { + boolean increase = !select.getGroupByExpressions().isEmpty(); + if (increase) { + select.clearGroupByExpressions(); + } else { + select.setGroupByExpressions(select.getFetchColumns()); + } + return increase; + } + + boolean mutateHaving(MySQLSelect select) { + if (select.getGroupByExpressions().isEmpty()) { + select.setGroupByExpressions(select.getFetchColumns()); + select.setHavingClause(generateExpression()); + return false; + } else { + if (select.getHavingClause() == null) { + select.setHavingClause(generateExpression()); + return false; + } else { + select.setHavingClause(null); + return true; + } + } + } + + boolean mutateAnd(MySQLSelect select) { + if (select.getWhereClause() == null) { + select.setWhereClause(generateExpression()); + } else { + MySQLExpression newWhere = new MySQLBinaryLogicalOperation(select.getWhereClause(), generateExpression(), + MySQLBinaryLogicalOperator.AND); + select.setWhereClause(newWhere); + } + return false; + } + + boolean mutateOr(MySQLSelect select) { + if (select.getWhereClause() == null) { + select.setWhereClause(generateExpression()); + return false; + } else { + MySQLExpression newWhere = new MySQLBinaryLogicalOperation(select.getWhereClause(), generateExpression(), + MySQLBinaryLogicalOperator.OR); + select.setWhereClause(newWhere); + return true; + } + } } diff --git a/src/sqlancer/mysql/gen/MySQLHintGenerator.java b/src/sqlancer/mysql/gen/MySQLHintGenerator.java new file mode 100644 index 000000000..141aea279 --- /dev/null +++ b/src/sqlancer/mysql/gen/MySQLHintGenerator.java @@ -0,0 +1,204 @@ +package sqlancer.mysql.gen; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.mysql.MySQLSchema.MySQLIndex; +import sqlancer.mysql.MySQLSchema.MySQLTable; +import sqlancer.mysql.ast.MySQLSelect; +import sqlancer.mysql.ast.MySQLText; + +public class MySQLHintGenerator { + + private final MySQLSelect select; + private final List tables; + private final StringBuilder sb = new StringBuilder(); + + enum OptimizeHint { + BKA, NO_BKA, BNL, NO_BNL, DERIVED_CONDITION_PUSHDOWN, NO_DERIVED_CONDITION_PUSHDOWN, GROUP_INDEX, + NO_GROUP_INDEX, HASH_JOIN, NO_HASH_JOIN, INDEX, NO_INDEX, INDEX_MERGE, NO_INDEX_MERGE, JOIN_FIXED_ORDER, + JOIN_INDEX, NO_JOIN_INDEX, JOIN_ORDER, JOIN_PREFIX, JOIN_SUFFIX, MERGE, NO_MERGE, MRR, NO_MRR, NO_ICP, + NO_RANGE_OPTIMIZATION, ORDER_INDEX, NO_ORDER_INDEX, SEMIJOIN, NO_SEMIJOIN, SKIP_SCAN, NO_SKIP_SCAN + } + + public MySQLHintGenerator(MySQLSelect select, List tables) { + this.select = select; + this.tables = tables; + } + + public static void generateHints(MySQLSelect select, List tables) { + new MySQLHintGenerator(select, tables).randomHint(); + } + + public static List generateAllHints(MySQLSelect select, List tables) { + MySQLHintGenerator generator = new MySQLHintGenerator(select, tables); + return generator.allHints(); + } + + private void randomHint() { + OptimizeHint chosenhint = Randomly.fromOptions(OptimizeHint.values()); + generate(chosenhint); + } + + private List allHints() { + List results = new ArrayList<>(); + for (OptimizeHint hint : OptimizeHint.values()) { + try { + MySQLText generatedHint = generate(hint); + results.add(generatedHint); + } catch (IgnoreMeException e) { + continue; + } + } + return results; + } + + private MySQLText generate(OptimizeHint chosenhint) { + sb.setLength(0); + + switch (chosenhint) { + case BKA: + tablesHint("BKA"); + break; + case NO_BKA: + tablesHint("NO_BKA"); + break; + case BNL: + tablesHint("BNL"); + break; + case NO_BNL: + tablesHint("NO_BNL"); + break; + case DERIVED_CONDITION_PUSHDOWN: + tablesHint("DERIVED_CONDITION_PUSHDOWN"); + break; + case NO_DERIVED_CONDITION_PUSHDOWN: + tablesHint("NO_DERIVED_CONDITION_PUSHDOWN"); + break; + case GROUP_INDEX: + indexesHint("GROUP_INDEX"); + break; + case NO_GROUP_INDEX: + indexesHint("NO_GROUP_INDEX"); + break; + case HASH_JOIN: + tablesHint("HASH_JOIN"); + break; + case NO_HASH_JOIN: + tablesHint("NO_HASH_JOIN"); + break; + case INDEX: + indexesHint("INDEX"); + break; + case NO_INDEX: + indexesHint("NO_INDEX"); + break; + case INDEX_MERGE: + indexesHint("INDEX_MERGE"); + break; + case NO_INDEX_MERGE: + indexesHint("NO_INDEX_MERGE"); + break; + case JOIN_FIXED_ORDER: + tablesHint("JOIN_FIXED_ORDER"); + break; + case JOIN_INDEX: + indexesHint("JOIN_INDEX"); + break; + case NO_JOIN_INDEX: + indexesHint("NO_JOIN_INDEX"); + break; + case JOIN_ORDER: + tablesHint("JOIN_ORDER"); + break; + case JOIN_PREFIX: + tablesHint("JOIN_PREFIX"); + break; + case JOIN_SUFFIX: + tablesHint("JOIN_SUFFIX"); + break; + case MERGE: + tablesHint("MERGE"); + break; + case NO_MERGE: + tablesHint("NO_MERGE"); + break; + case MRR: + indexesHint("MRR"); + break; + case NO_MRR: + indexesHint("NO_MRR"); + break; + case NO_ICP: + indexesHint("NO_ICP"); + break; + case NO_RANGE_OPTIMIZATION: + indexesHint("NO_RANGE_OPTIMIZATION"); + break; + case ORDER_INDEX: + indexesHint("ORDER_INDEX"); + break; + case NO_ORDER_INDEX: + indexesHint("NO_ORDER_INDEX"); + break; + case SEMIJOIN: + semiHint("SEMIJOIN"); + break; + case NO_SEMIJOIN: + semiHint("NO_SEMIJOIN"); + break; + case SKIP_SCAN: + indexesHint("SKIP_SCAN"); + break; + case NO_SKIP_SCAN: + indexesHint("NO_SKIP_SCAN"); + break; + default: + throw new AssertionError(); + } + MySQLText hint = new MySQLText(sb.toString()); + select.setHint(hint); + return hint; + } + + private void indexesHint(String string) { + sb.append(string); + sb.append("("); + MySQLTable table = Randomly.fromList(tables); + List allIndexes = table.getIndexes(); + sb.append(table.getName()); + sb.append(", "); + if (allIndexes.isEmpty()) { + sb.append("PRIMARY"); + } else { + List indexSubset = Randomly.nonEmptySubset(allIndexes); + sb.append(indexSubset.stream().map(i -> i.getIndexName()).distinct().collect(Collectors.joining(", "))); + } + sb.append(")"); + } + + private void tablesHint(String string) { + sb.append(string); + sb.append("("); + appendTables(); + sb.append(")"); + } + + private void semiHint(String string) { + sb.append(string); + sb.append("("); + String[] options = { "DUPSWEEDOUT", "FIRSTMATCH", "LOOSESCAN", "MATERIALIZATION" }; + List chosenOptions = Randomly.nonEmptySubset(options); + sb.append(chosenOptions.stream().collect(Collectors.joining(", "))); + sb.append(")"); + } + + private void appendTables() { + List tableSubset = Randomly.nonEmptySubset(tables); + sb.append(tableSubset.stream().map(t -> t.getName()).collect(Collectors.joining(", "))); + } + +} diff --git a/src/sqlancer/mysql/gen/MySQLInsertGenerator.java b/src/sqlancer/mysql/gen/MySQLInsertGenerator.java index 496600905..86083fd2d 100644 --- a/src/sqlancer/mysql/gen/MySQLInsertGenerator.java +++ b/src/sqlancer/mysql/gen/MySQLInsertGenerator.java @@ -7,6 +7,7 @@ import sqlancer.Randomly; import sqlancer.common.query.ExpectedErrors; import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.mysql.MySQLErrors; import sqlancer.mysql.MySQLGlobalState; import sqlancer.mysql.MySQLSchema.MySQLColumn; import sqlancer.mysql.MySQLSchema.MySQLTable; @@ -19,16 +20,21 @@ public class MySQLInsertGenerator { private final ExpectedErrors errors = new ExpectedErrors(); private final MySQLGlobalState globalState; - public MySQLInsertGenerator(MySQLGlobalState globalState) { + public MySQLInsertGenerator(MySQLGlobalState globalState, MySQLTable table) { this.globalState = globalState; - table = globalState.getSchema().getRandomTable(); + this.table = table; } public static SQLQueryAdapter insertRow(MySQLGlobalState globalState) throws SQLException { + MySQLTable table = globalState.getSchema().getRandomTable(); + return insertRow(globalState, table); + } + + public static SQLQueryAdapter insertRow(MySQLGlobalState globalState, MySQLTable table) throws SQLException { if (Randomly.getBoolean()) { - return new MySQLInsertGenerator(globalState).generateInsert(); + return new MySQLInsertGenerator(globalState, table).generateInsert(); } else { - return new MySQLInsertGenerator(globalState).generateReplace(); + return new MySQLInsertGenerator(globalState, table).generateReplace(); } } @@ -83,14 +89,7 @@ private SQLQueryAdapter generateInto() { } sb.append(")"); } - errors.add("doesn't have a default value"); - errors.add("Data truncation"); - errors.add("Incorrect integer value"); - errors.add("Duplicate entry"); - errors.add("Data truncated for functional index"); - errors.add("Data truncated for column"); - errors.add("cannot be null"); - errors.add("Incorrect decimal value"); + MySQLErrors.addInsertUpdateErrors(errors); return new SQLQueryAdapter(sb.toString(), errors); } diff --git a/src/sqlancer/mysql/gen/MySQLRandomQuerySynthesizer.java b/src/sqlancer/mysql/gen/MySQLRandomQuerySynthesizer.java new file mode 100644 index 000000000..d701072ee --- /dev/null +++ b/src/sqlancer/mysql/gen/MySQLRandomQuerySynthesizer.java @@ -0,0 +1,67 @@ +package sqlancer.mysql.gen; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.mysql.MySQLGlobalState; +import sqlancer.mysql.MySQLSchema.MySQLTables; +import sqlancer.mysql.ast.MySQLConstant; +import sqlancer.mysql.ast.MySQLExpression; +import sqlancer.mysql.ast.MySQLSelect; +import sqlancer.mysql.ast.MySQLTableReference; + +public final class MySQLRandomQuerySynthesizer { + + private MySQLRandomQuerySynthesizer() { + } + + public static MySQLSelect generate(MySQLGlobalState globalState, int nrColumns) { + MySQLTables tables = globalState.getSchema().getRandomTableNonEmptyTables(); + MySQLExpressionGenerator gen = new MySQLExpressionGenerator(globalState).setColumns(tables.getColumns()); + MySQLSelect select = new MySQLSelect(); + + List allColumns = new ArrayList<>(); + List columnsWithoutAggregations = new ArrayList<>(); + + boolean hasGeneratedAggregate = false; + + select.setSelectType(Randomly.fromOptions(MySQLSelect.SelectType.values())); + for (int i = 0; i < nrColumns; i++) { + if (Randomly.getBoolean()) { + MySQLExpression expression = gen.generateExpression(); + allColumns.add(expression); + columnsWithoutAggregations.add(expression); + } else { + allColumns.add(gen.generateAggregate()); + hasGeneratedAggregate = true; + } + } + select.setFetchColumns(allColumns); + + List tableList = tables.getTables().stream().map(t -> new MySQLTableReference(t)) + .collect(Collectors.toList()); + select.setFromList(tableList); + if (Randomly.getBoolean()) { + select.setWhereClause(gen.generateExpression()); + } + if (Randomly.getBooleanWithRatherLowProbability()) { + select.setOrderByClauses(gen.generateOrderBys()); + } + if (hasGeneratedAggregate || Randomly.getBoolean()) { + select.setGroupByExpressions(columnsWithoutAggregations); + if (Randomly.getBoolean()) { + select.setHavingClause(gen.generateHavingClause()); + } + } + if (Randomly.getBoolean()) { + select.setLimitClause(MySQLConstant.createIntConstant(Randomly.getPositiveOrZeroNonCachedInteger())); + if (Randomly.getBoolean()) { + select.setOffsetClause(MySQLConstant.createIntConstant(Randomly.getPositiveOrZeroNonCachedInteger())); + } + } + return select; + } + +} diff --git a/src/sqlancer/mysql/gen/MySQLSetGenerator.java b/src/sqlancer/mysql/gen/MySQLSetGenerator.java index 2db20fee9..e350685ef 100644 --- a/src/sqlancer/mysql/gen/MySQLSetGenerator.java +++ b/src/sqlancer/mysql/gen/MySQLSetGenerator.java @@ -1,5 +1,7 @@ package sqlancer.mysql.gen; +import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.function.Function; import java.util.stream.Collectors; @@ -7,6 +9,7 @@ import sqlancer.MainOptions; import sqlancer.Randomly; import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.mysql.MySQLBugs; import sqlancer.mysql.MySQLGlobalState; public class MySQLSetGenerator { @@ -34,11 +37,11 @@ private enum Action { AUTOCOMMIT("autocommit", (r) -> 1, Scope.GLOBAL, Scope.SESSION), // BIG_TABLES("big_tables", (r) -> Randomly.fromOptions("OFF", "ON"), Scope.GLOBAL, Scope.SESSION), // - COMPLETION_TYPE("completion_type", (r) -> Randomly.fromOptions("'NO_CHAIN'", "'CHAIN'", "'RELEASE'", 0, 1, 2), - Scope.GLOBAL), // + COMPLETION_TYPE("completion_type", + (r) -> Randomly.fromOptions("'NO_CHAIN'", "'CHAIN'", "'RELEASE'", "0", "1", "2"), Scope.GLOBAL), // BULK_INSERT_CACHE_SIZE("bulk_insert_buffer_size", (r) -> r.getLong(0, Long.MAX_VALUE), Scope.GLOBAL, // Scope.SESSION), // - CONCURRENT_INSERT("concurrent_insert", (r) -> Randomly.fromOptions("NEVER", "AUTO", "ALWAYS", 0, 1, 2), // + CONCURRENT_INSERT("concurrent_insert", (r) -> Randomly.fromOptions("NEVER", "AUTO", "ALWAYS", "0", "1", "2"), // Scope.GLOBAL), // CTE_MAX_RECURSION_DEPTH("cte_max_recursion_depth", // (r) -> r.getLong(0, 4294967295L), Scope.GLOBAL), // @@ -66,8 +69,6 @@ private enum Action { MAX_SP_RECURSION_DEPTH("max_sp_recursion_depth", (r) -> r.getLong(0, 255), Scope.GLOBAL, Scope.SESSION), // MYISAM_DATA_POINTER_SIZE("myisam_data_pointer_size", (r) -> r.getLong(2, 7), Scope.GLOBAL), // MYISAM_MAX_SORT_FILE_SIZE("myisam_max_sort_file_size", (r) -> r.getLong(0, 9223372036854775807L), Scope.GLOBAL), // - MYISAM_REPAIR_THREADS("myisam_repair_threads", (r) -> r.getLong(1, Long.MAX_VALUE), Scope.GLOBAL, - Scope.SESSION), // MYISAM_SORT_BUFFER_SIZE("myisam_sort_buffer_size", (r) -> r.getLong(4096, Long.MAX_VALUE), Scope.GLOBAL, Scope.SESSION), // MYISAM_STATS_METHOD("myisam_stats_method", @@ -98,7 +99,6 @@ private enum Action { SCHEMA_DEFINITION_CACHE("schema_definition_cache", (r) -> r.getLong(256, 524288), Scope.GLOBAL), // SHOW_CREATE_TABLE_VERBOSITY("show_create_table_verbosity", (r) -> Randomly.fromOptions("OFF", "ON"), Scope.GLOBAL, Scope.SESSION), // - SHOW_OLD_TEMPORALS("show_old_temporals", (r) -> Randomly.fromOptions("OFF", "ON"), Scope.GLOBAL, Scope.SESSION), /* * sort_buffer_size is commented out as a workaround for https://bugs.mysql.com/bug.php?id=95969 */ @@ -133,11 +133,12 @@ private enum Action { private static String getOptimizerSwitchConfiguration(Randomly r) { StringBuilder sb = new StringBuilder(); sb.append("'"); - String[] options = { "batched_key_access", "block_nested_loop", "condition_fanout_filter", "derived_merge", - "engine_condition_pushdown", "index_condition_pushdown", "use_index_extensions", "index_merge", - "index_merge_intersection", "index_merge_sort_union", "index_merge_union", "use_invisible_indexes", - "mrr", "mrr_cost_based", "skip_scan", "semijoin", "duplicateweedout", "firstmatch", "loosescan", - "materialization", "subquery_materialization_cost_based" }; + String[] options = { "index_merge", "index_merge_union", "index_merge_sort_union", + "index_merge_intersection", "index_condition_pushdown", "mrr", "mrr_cost_based", + "block_nested_loop", "batched_key_access", "materialization", "semijoin", "loosescan", "firstmatch", + "duplicateweedout", "subquery_materialization_cost_based", "use_index_extensions", + "condition_fanout_filter", "derived_merge", "use_invisible_indexes", "skip_scan", "hash_join", + "subquery_to_derived", "prefer_ordering_index", "derived_condition_pushdown" }; List optionSubset = Randomly.nonEmptySubset(options); sb.append(optionSubset.stream().map(s -> s + "=" + Randomly.fromOptions("on", "off")) .collect(Collectors.joining(","))); @@ -190,4 +191,46 @@ private SQLQueryAdapter get() { return new SQLQueryAdapter(sb.toString()); } + public static SQLQueryAdapter resetOptimizer() { + return new SQLQueryAdapter("SET optimizer_switch='default'"); + } + + public static List getAllOptimizer(MySQLGlobalState globalState) { + List result = new ArrayList<>(); + String[] options = { "index_merge", "index_merge_union", "index_merge_sort_union", "index_merge_intersection", + "engine_condition_pushdown", "index_condition_pushdown", "mrr", "mrr_cost_based", "block_nested_loop", + "batched_key_access", "materialization", "semijoin", "loosescan", "firstmatch", "duplicateweedout", + "subquery_materialization_cost_based", "use_index_extensions", "condition_fanout_filter", + "derived_merge", "use_invisible_indexes", "skip_scan", "hash_join", "subquery_to_derived", + "prefer_ordering_index", "derived_condition_pushdown" }; + + List availableOptions = new ArrayList<>(Arrays.asList(options)); + if (MySQLBugs.bug112242) { + availableOptions.remove("use_invisible_indexes"); + } + if (MySQLBugs.bug112243) { + availableOptions.remove("subquery_to_derived"); + } + if (MySQLBugs.bug112264) { + availableOptions.remove("block_nested_loop"); + } + + StringBuilder sb = new StringBuilder(); + sb.append("SET "); + if (globalState.getOptions().getNumberConcurrentThreads() == 1 && Randomly.getBoolean()) { + sb.append("GLOBAL"); + } else { + sb.append("SESSION"); + } + sb.append(" optimizer_switch = '%s'"); + + for (String option : availableOptions) { + result.add(new SQLQueryAdapter(String.format(sb.toString(), option + "=on"))); + result.add(new SQLQueryAdapter(String.format(sb.toString(), option + "=off"))); + result.add(new SQLQueryAdapter(String.format(sb.toString(), option + "=default"))); + } + + return result; + } + } diff --git a/src/sqlancer/mysql/gen/MySQLTableGenerator.java b/src/sqlancer/mysql/gen/MySQLTableGenerator.java index 579d9a4c9..bc0533295 100644 --- a/src/sqlancer/mysql/gen/MySQLTableGenerator.java +++ b/src/sqlancer/mysql/gen/MySQLTableGenerator.java @@ -18,13 +18,11 @@ import sqlancer.mysql.MySQLSchema.MySQLTable.MySQLEngine; public class MySQLTableGenerator { - private final StringBuilder sb = new StringBuilder(); private final boolean allowPrimaryKey; private boolean setPrimaryKey; private final String tableName; private final Randomly r; - private int columnId; private boolean tableHasNullableColumn; private MySQLEngine engine; private int keysSpecified; @@ -65,18 +63,18 @@ private SQLQueryAdapter create() { if (i != 0) { sb.append(", "); } - appendColumn(); + appendColumn(i); } sb.append(")"); sb.append(" "); appendTableOptions(); appendPartitionOptions(); - if ((tableHasNullableColumn || setPrimaryKey) && engine == MySQLEngine.CSV) { + if (engine == MySQLEngine.CSV && (tableHasNullableColumn || setPrimaryKey)) { if (true) { // TODO // results in an error throw new IgnoreMeException(); } - } else if ((tableHasNullableColumn || keysSpecified > 1) && engine == MySQLEngine.ARCHIVE) { + } else if (engine == MySQLEngine.ARCHIVE && (tableHasNullableColumn || keysSpecified > 1)) { errors.add("Too many keys specified; max 1 keys allowed"); errors.add("Table handler doesn't support NULL in given index"); addCommonErrors(errors); @@ -176,9 +174,10 @@ private void appendTableOptions() { sb.append("AUTO_INCREMENT = "); sb.append(r.getPositiveInteger()); break; + // The valid range for avg_row_length is [0,4294967295] case AVG_ROW_LENGTH: sb.append("AVG_ROW_LENGTH = "); - sb.append(r.getPositiveInteger()); + sb.append(r.getLong(0, 4294967295L + 1)); break; case CHECKSUM: sb.append("CHECKSUM = 1"); @@ -212,9 +211,10 @@ private void appendTableOptions() { sb.append("INSERT_METHOD = "); sb.append(Randomly.fromOptions("NO", "FIRST", "LAST")); break; + // The valid range for key_block_size is [0,65535] case KEY_BLOCK_SIZE: sb.append("KEY_BLOCK_SIZE = "); - sb.append(r.getPositiveInteger()); + sb.append(r.getInteger(0, 65535 + 1)); break; case MAX_ROWS: sb.append("MAX_ROWS = "); @@ -246,27 +246,21 @@ private void appendTableOptions() { } } - private void appendColumn() { + private void appendColumn(int columnId) { String columnName = DBMSCommon.createColumnName(columnId); columns.add(columnName); sb.append(columnName); appendColumnDefinition(); - columnId++; } private enum ColumnOptions { NULL_OR_NOT_NULL, UNIQUE, COMMENT, COLUMN_FORMAT, STORAGE, PRIMARY_KEY } - private void appendColumnDefinition() { - sb.append(" "); - MySQLDataType randomType = MySQLDataType.getRandom(globalState); - boolean isTextType = randomType == MySQLDataType.VARCHAR; - appendTypeString(randomType); - sb.append(" "); + private void appendColumnOption(MySQLDataType type) { + boolean isTextType = type == MySQLDataType.VARCHAR; boolean isNull = false; boolean columnHasPrimaryKey = false; - List columnOptions = Randomly.subset(ColumnOptions.values()); if (!columnOptions.contains(ColumnOptions.NULL_OR_NOT_NULL)) { tableHasNullableColumn = true; @@ -322,10 +316,17 @@ private void appendColumnDefinition() { throw new AssertionError(); } } + } + private void appendColumnDefinition() { + sb.append(" "); + MySQLDataType randomType = MySQLDataType.getRandom(globalState); + appendType(randomType); + sb.append(" "); + appendColumnOption(randomType); } - private void appendTypeString(MySQLDataType randomType) { + private void appendType(MySQLDataType randomType) { switch (randomType) { case DECIMAL: sb.append("DECIMAL"); @@ -336,7 +337,7 @@ private void appendTypeString(MySQLDataType randomType) { if (Randomly.getBoolean()) { sb.append("("); sb.append(Randomly.getNotCachedInteger(0, 255)); // Display width out of range for column 'c0' (max = - // 255) + // 255) sb.append(")"); } break; diff --git a/src/sqlancer/mysql/gen/MySQLUpdateGenerator.java b/src/sqlancer/mysql/gen/MySQLUpdateGenerator.java new file mode 100644 index 000000000..55ba3dd45 --- /dev/null +++ b/src/sqlancer/mysql/gen/MySQLUpdateGenerator.java @@ -0,0 +1,58 @@ +package sqlancer.mysql.gen; + +import java.sql.SQLException; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.gen.AbstractUpdateGenerator; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.mysql.MySQLErrors; +import sqlancer.mysql.MySQLGlobalState; +import sqlancer.mysql.MySQLSchema.MySQLColumn; +import sqlancer.mysql.MySQLSchema.MySQLTable; +import sqlancer.mysql.MySQLVisitor; + +public class MySQLUpdateGenerator extends AbstractUpdateGenerator { + + private final MySQLGlobalState globalState; + private MySQLExpressionGenerator gen; + + public MySQLUpdateGenerator(MySQLGlobalState globalState) { + this.globalState = globalState; + } + + public static SQLQueryAdapter create(MySQLGlobalState globalState) throws SQLException { + return new MySQLUpdateGenerator(globalState).generate(); + } + + private SQLQueryAdapter generate() throws SQLException { + MySQLTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + List columns = table.getRandomNonEmptyColumnSubset(); + gen = new MySQLExpressionGenerator(globalState).setColumns(table.getColumns()); + sb.append("UPDATE "); + sb.append(table.getName()); + sb.append(" SET "); + updateColumns(columns); + if (Randomly.getBoolean()) { + sb.append(" WHERE "); + MySQLErrors.addExpressionErrors(errors); + sb.append(MySQLVisitor.asString(gen.generateExpression())); + } + MySQLErrors.addInsertUpdateErrors(errors); + errors.add("doesn't have this option"); + + return new SQLQueryAdapter(sb.toString(), errors); + } + + @Override + protected void updateValue(MySQLColumn column) { + if (Randomly.getBoolean()) { + sb.append(gen.generateConstant()); + } else if (Randomly.getBoolean()) { + sb.append("DEFAULT"); + } else { + sb.append(MySQLVisitor.asString(gen.generateExpression())); + } + } + +} diff --git a/src/sqlancer/mysql/gen/datadef/MySQLIndexGenerator.java b/src/sqlancer/mysql/gen/datadef/MySQLIndexGenerator.java index 3f418385d..550893db5 100644 --- a/src/sqlancer/mysql/gen/datadef/MySQLIndexGenerator.java +++ b/src/sqlancer/mysql/gen/datadef/MySQLIndexGenerator.java @@ -5,6 +5,7 @@ import sqlancer.Randomly; import sqlancer.common.query.ExpectedErrors; import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.mysql.MySQLBugs; import sqlancer.mysql.MySQLErrors; import sqlancer.mysql.MySQLGlobalState; import sqlancer.mysql.MySQLSchema; @@ -78,7 +79,11 @@ public SQLQueryAdapter create() { if (Randomly.getBoolean() && c.getType() == MySQLDataType.VARCHAR) { sb.append("("); // TODO for string - sb.append(r.getInteger(1, 5)); + if (MySQLBugs.bug114534) { + sb.append(r.getInteger(2, 5)); + } else { + sb.append(r.getInteger(1, 5)); + } sb.append(")"); } if (Randomly.getBoolean()) { diff --git a/src/sqlancer/mysql/oracle/MySQLDQEOracle.java b/src/sqlancer/mysql/oracle/MySQLDQEOracle.java new file mode 100644 index 000000000..8ddb6f315 --- /dev/null +++ b/src/sqlancer/mysql/oracle/MySQLDQEOracle.java @@ -0,0 +1,498 @@ +package sqlancer.mysql.oracle; + +import static sqlancer.ComparatorHelper.getResultSetFirstColumnAsString; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import com.beust.jcommander.Strings; + +import sqlancer.Randomly; +import sqlancer.common.oracle.DQEBase; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.query.SQLQueryError; +import sqlancer.common.query.SQLancerResultSet; +import sqlancer.common.schema.AbstractRelationalTable; +import sqlancer.common.schema.AbstractTable; +import sqlancer.common.schema.AbstractTables; +import sqlancer.mysql.MySQLErrors; +import sqlancer.mysql.MySQLGlobalState; +import sqlancer.mysql.MySQLSchema; +import sqlancer.mysql.MySQLSchema.MySQLColumn; +import sqlancer.mysql.MySQLSchema.MySQLTable; +import sqlancer.mysql.MySQLSchema.MySQLTables; +import sqlancer.mysql.MySQLVisitor; +import sqlancer.mysql.ast.MySQLExpression; +import sqlancer.mysql.gen.MySQLExpressionGenerator; + +public class MySQLDQEOracle extends DQEBase implements TestOracle { + + private final MySQLSchema schema; + private static final String APPEND_ORDER_BY = "%s ORDER BY %s"; + private static final String APPEND_LIMIT = "%s LIMIT %d"; + private final List orderColumns = new ArrayList<>(); + private boolean generateLimit; + private boolean generateOrderBy; + private boolean operateOnSingleTable; + private int limit; + + public MySQLDQEOracle(MySQLGlobalState state) { + super(state); + schema = state.getSchema(); + + MySQLErrors.addExpressionErrors(selectExpectedErrors); + + MySQLErrors.addExpressionErrors(updateExpectedErrors); + MySQLErrors.addInsertUpdateErrors(updateExpectedErrors); + + MySQLErrors.addExpressionErrors(deleteExpectedErrors); + deleteExpectedErrors.add("a foreign key constraint fails"); + } + + @Override + public String generateSelectStatement(AbstractTables tables, String tableName, String whereClauseStr) { + operateOnSingleTable = tables.getTables().size() == 1; + List selectColumns = new ArrayList<>(); + MySQLTables mySQLTables = (MySQLTables) tables; + for (MySQLTable table : mySQLTables.getTables()) { + selectColumns.add(table.getName() + "." + COLUMN_ROWID); + } + if (operateOnSingleTable && Randomly.getBooleanWithSmallProbability()) { + generateOrderBy = true; + // generate order by columns + for (MySQLColumn column : Randomly.nonEmptySubset(mySQLTables.getColumns())) { + orderColumns.add(column.getFullQualifiedName()); + } + + if (Randomly.getBooleanWithRatherLowProbability()) { + generateLimit = true; + limit = (int) Randomly.getNotCachedInteger(1, 10); + } + } + + String selectStmt = String.format("SELECT %s FROM %s WHERE %s", Strings.join(",", selectColumns).toLowerCase(), + tableName, whereClauseStr); + if (generateOrderBy) { + selectStmt = String.format(APPEND_ORDER_BY, selectStmt, String.join(",", orderColumns)); + if (generateLimit) { + selectStmt = String.format(APPEND_LIMIT, selectStmt, limit); + } + } + return selectStmt; + } + + @Override + public String generateUpdateStatement(AbstractTables tables, String tableName, String whereClauseStr) { + List updateColumns = new ArrayList<>(); + MySQLTables mySQLTables = (MySQLTables) tables; + for (MySQLTable table : mySQLTables.getTables()) { + updateColumns.add(String.format("%s = 1", table.getName() + "." + COLUMN_UPDATED)); + } + String updateStmt = String.format("UPDATE %s SET %s WHERE %s", tableName, Strings.join(",", updateColumns), + whereClauseStr); + if (generateOrderBy) { + updateStmt = String.format(APPEND_ORDER_BY, updateStmt, String.join(",", orderColumns)); + if (generateLimit) { + updateStmt = String.format(APPEND_LIMIT, updateStmt, limit); + } + } + return updateStmt; + } + + @Override + public String generateDeleteStatement(String tableName, String whereClauseStr) { + String deleteStmt; + if (operateOnSingleTable) { + deleteStmt = String.format("DELETE FROM %s WHERE %s", tableName, whereClauseStr); + if (generateOrderBy) { + deleteStmt = String.format(APPEND_ORDER_BY, deleteStmt, String.join(",", orderColumns)); + if (generateLimit) { + deleteStmt = String.format(APPEND_LIMIT, deleteStmt, limit); + } + } + } else { + deleteStmt = String.format("DELETE %s FROM %s WHERE %s", tableName, tableName, whereClauseStr); + } + return deleteStmt; + } + + @Override + public void check() throws SQLException { + + MySQLTables tables = schema.getRandomTableNonEmptyTables(); + String tableName = tables.getTables().stream().map(AbstractTable::getName).collect(Collectors.joining(",")); + + // DQE does not support aggregate functions, windows functions + // This method does not generate them, may need some configurations if they can be generated + MySQLExpressionGenerator expressionGenerator = new MySQLExpressionGenerator(state) + .setColumns(tables.getColumns()); + MySQLExpression whereClause = expressionGenerator.generateExpression(); + + // MySQLVisitor is not deterministic, we should keep it only once. + // Especially, in MySQLUnaryPostfixOperation and MySQLUnaryPrefixOperation + String whereClauseStr = MySQLVisitor.asString(whereClause); + + String selectStmt = generateSelectStatement(tables, tableName, whereClauseStr); + + String updateStmt = generateUpdateStatement(tables, tableName, whereClauseStr); + + String deleteStmt = generateDeleteStatement(tableName, whereClauseStr); + + for (MySQLTable table : tables.getTables()) { + addAuxiliaryColumns(table); + } + + state.getState().getLocalState().log(selectStmt); + SQLQueryResult selectExecutionResult = executeSelect(selectStmt, tables); + state.getState().getLocalState().log(selectExecutionResult.getAccessedRows().values().toString()); + state.getState().getLocalState().log(selectExecutionResult.getQueryErrors().toString()); + + state.getState().getLocalState().log(updateStmt); + SQLQueryResult updateExecutionResult = executeUpdate(updateStmt, tables); + state.getState().getLocalState().log(updateExecutionResult.getAccessedRows().values().toString()); + state.getState().getLocalState().log(updateExecutionResult.getQueryErrors().toString()); + + state.getState().getLocalState().log(deleteStmt); + SQLQueryResult deleteExecutionResult = executeDelete(deleteStmt, tables); + state.getState().getLocalState().log(deleteExecutionResult.getAccessedRows().values().toString()); + state.getState().getLocalState().log(deleteExecutionResult.getQueryErrors().toString()); + + String compareSelectAndUpdate = compareSelectAndUpdate(selectExecutionResult, updateExecutionResult); + String compareSelectAndDelete = compareSelectAndDelete(selectExecutionResult, deleteExecutionResult); + String compareUpdateAndDelete = compareUpdateAndDelete(updateExecutionResult, deleteExecutionResult); + + String errorMessage = compareSelectAndUpdate == null ? "" : compareSelectAndUpdate + "\n"; + errorMessage += compareSelectAndDelete == null ? "" : compareSelectAndDelete + "\n"; + errorMessage += compareUpdateAndDelete == null ? "" : compareUpdateAndDelete + "\n"; + + if (!errorMessage.isEmpty()) { + throw new AssertionError(errorMessage); + } + + for (MySQLTable table : tables.getTables()) { + dropAuxiliaryColumns(table); + } + } + + public String compareSelectAndUpdate(SQLQueryResult selectResult, SQLQueryResult updateResult) { + if (updateResult.hasEmptyErrors()) { + if (!selectResult.hasEmptyErrors()) { + return "SELECT has errors, but UPDATE does not."; + } + if (!selectResult.hasSameAccessedRows(updateResult)) { + return "SELECT accessed different rows from UPDATE."; + } + } else { // update has errors + if (hasUpdateSpecificErrors(updateResult)) { + if (updateResult.hasAccessedRows()) { + return "UPDATE accessed non-empty rows when specific errors happen."; + } else { + // we do not compare update with select when update has specific errors + return null; + } + } + + // update errors should all appear in the select errors + List selectErrors = new ArrayList<>(selectResult.getQueryErrors()); + for (int i = 0; i < updateResult.getQueryErrors().size(); i++) { + SQLQueryError updateError = updateResult.getQueryErrors().get(i); + if (!isFound(selectErrors, updateError)) { + return "SELECT has different errors from UPDATE."; + } + } + + if (hasStopErrors(updateResult)) { + if (updateResult.hasAccessedRows()) { + return "UPDATE accessed non-empty rows when stop errors happen."; + } + } else { + if (!selectResult.hasSameAccessedRows(updateResult)) { + return "SELECT accessed different rows from UPDATE when errors happen."; + } + } + } + return null; + } + + /** + * + * @param selectErrors + * selectQueryErrors + * @param targetError + * update or delete queryError + * + * @return is targetError found in selectQueryErrors + */ + private static boolean isFound(List selectErrors, SQLQueryError targetError) { + boolean found = false; + for (int i = 0; i < selectErrors.size(); i++) { + SQLQueryError selectError = selectErrors.get(i); + if (selectError.hasSameCodeAndMessage(targetError)) { + selectErrors.remove(i); + found = true; + break; + } + } + return found; + } + + public String compareSelectAndDelete(SQLQueryResult selectResult, SQLQueryResult deleteResult) { + if (deleteResult.hasEmptyErrors()) { + if (!selectResult.hasEmptyErrors()) { + return "SELECT has errors, but DELETE does not."; + } + if (!selectResult.hasSameAccessedRows(deleteResult)) { + return "SELECT accessed different rows from DELETE."; + } + } else { // delete has errors + if (hasDeleteSpecificErrors(deleteResult)) { + if (deleteResult.hasAccessedRows()) { + return "DELETE accessed non-empty rows when specific errors happen."; + } else { + // we do not compare delete with select when delete has specific errors + return null; + } + } + + // delete errors should all appear in the select errors + List selectErrors = new ArrayList<>(selectResult.getQueryErrors()); + for (int i = 0; i < deleteResult.getQueryErrors().size(); i++) { + SQLQueryError deleteError = deleteResult.getQueryErrors().get(i); + if (!isFound(selectErrors, deleteError)) { + return "SELECT has different errors from DELETE."; + } + } + + if (hasStopErrors(deleteResult)) { + if (deleteResult.hasAccessedRows()) { + return "DELETE accessed non-empty rows when stop errors happen."; + } + } else { + if (!selectResult.hasSameAccessedRows(deleteResult)) { + return "SELECT accessed different rows from DELETE when errors happen."; + } + } + } + return null; + } + + public String compareUpdateAndDelete(SQLQueryResult updateResult, SQLQueryResult deleteResult) { + if (updateResult.hasEmptyErrors() && deleteResult.hasEmptyErrors()) { + if (updateResult.hasSameAccessedRows(deleteResult)) { + return null; + } else { + return "UPDATE accessed different rows from DELETE."; + } + } else { // update or delete has errors + boolean hasSpecificErrors = false; + + if (hasUpdateSpecificErrors(updateResult)) { + hasSpecificErrors = true; + if (updateResult.hasAccessedRows()) { + return "UPDATE accessed non-empty rows when specific errors happen."; + } + } + + if (hasDeleteSpecificErrors(deleteResult)) { + hasSpecificErrors = true; + if (deleteResult.hasAccessedRows()) { + return "DELETE accessed non-empty rows when specific errors happen."; + } + } + + // when one of these statements has specific errors, do not compare them + if (hasSpecificErrors) { + return null; + } + + if (!updateResult.hasSameErrors(deleteResult)) { + return "UPDATE has different errors from DELETE."; + } else { + if (!hasStopErrors(updateResult)) { + if (!updateResult.hasSameAccessedRows(deleteResult)) { + return "UPDATE accessed different rows from DELETE."; + } + } else { + if (updateResult.hasAccessedRows() || deleteResult.hasAccessedRows()) { + return "UPDATE or DELETE accessed non-empty rows when stop errors happen."; + } + } + } + + return null; + } + } + + /* + * when update violates column constraints, such as not null, unique, primary key and generated column, we cannot + * compare it with other queries. + */ + private boolean hasUpdateSpecificErrors(SQLQueryResult updateResult) { + return updateResult.getQueryErrors().stream().anyMatch( + error -> new MySQLErrorCodeStrategy().getUpdateSpecificErrorCodes().contains(error.getCode())); + } + + /* + * when delete violates column constraints, such as foreign key, we cannot compare it with other queries. + */ + private boolean hasDeleteSpecificErrors(SQLQueryResult deleteResult) { + return deleteResult.getQueryErrors().stream().anyMatch( + error -> new MySQLErrorCodeStrategy().getDeleteSpecificErrorCodes().contains(error.getCode())); + + } + + private boolean hasStopErrors(SQLQueryResult queryResult) { + return queryResult.getQueryErrors().stream() + .anyMatch(error -> error.getLevel() == SQLQueryError.ErrorLevel.ERROR); + } + + private SQLQueryResult executeSelect(String selectStmt, MySQLTables tables) throws SQLException { + Map, Set> accessedRows = new HashMap<>(); + List queryErrors; + SQLancerResultSet resultSet = null; + try { + resultSet = new SQLQueryAdapter(selectStmt, selectExpectedErrors).executeAndGet(state, false); + } catch (SQLException ignored) { + // we ignore this error, and use get errors to catch it + } finally { + queryErrors = getErrors(); + + if (resultSet != null) { + for (MySQLTable table : tables.getTables()) { + HashSet rows = new HashSet<>(); + accessedRows.put(table, rows); + } + while (resultSet.next()) { + for (MySQLTable table : tables.getTables()) { + accessedRows.get(table).add(resultSet.getString(table.getName() + "." + COLUMN_ROWID)); + } + } + resultSet.close(); + } + } + + return new SQLQueryResult(accessedRows, queryErrors); + } + + private SQLQueryResult executeUpdate(String updateStmt, MySQLTables tables) throws SQLException { + Map, Set> accessedRows = new HashMap<>(); + List queryErrors; + try { + new SQLQueryAdapter("BEGIN").execute(state, false); + new SQLQueryAdapter(updateStmt, updateExpectedErrors).execute(state, false); + } catch (SQLException ignored) { + // we ignore this error, and we use get errors to catch it + } finally { + queryErrors = getErrors(); + + for (MySQLTable table : tables.getTables()) { + String tableName = table.getName(); + String rowId = tableName + "." + COLUMN_ROWID; + String updated = tableName + "." + COLUMN_UPDATED; + String selectRowIdWithUpdated = String.format("SELECT %s FROM %s WHERE %s = 1", rowId, tableName, + updated); + HashSet rows = new HashSet<>( + getResultSetFirstColumnAsString(selectRowIdWithUpdated, updateExpectedErrors, state)); + accessedRows.put(table, rows); + } + + new SQLQueryAdapter("ROLLBACK").execute(state, false); + } + + return new SQLQueryResult(accessedRows, queryErrors); + } + + private SQLQueryResult executeDelete(String deleteStmt, MySQLTables tables) throws SQLException { + Map, Set> accessedRows = new HashMap<>(); + List queryErrors; + try { + for (MySQLTable table : tables.getTables()) { + String tableName = table.getName(); + String rowId = tableName + "." + COLUMN_ROWID; + String selectRowId = String.format("SELECT %s FROM %s", rowId, tableName); + HashSet rows = new HashSet<>( + getResultSetFirstColumnAsString(selectRowId, deleteExpectedErrors, state)); + accessedRows.put(table, rows); + } + + new SQLQueryAdapter("BEGIN").execute(state, false); + new SQLQueryAdapter(deleteStmt, deleteExpectedErrors).execute(state, false); + } catch (SQLException ignored) { + // we ignore this error, and use get errors to catch it + } finally { + queryErrors = getErrors(); + + for (MySQLTable table : tables.getTables()) { + String tableName = table.getName(); + String rowId = tableName + "." + COLUMN_ROWID; + String selectRowId = String.format("SELECT %s FROM %s", rowId, tableName); + HashSet rows = new HashSet<>( + getResultSetFirstColumnAsString(selectRowId, deleteExpectedErrors, state)); + accessedRows.get(table).removeAll(rows); + } + + new SQLQueryAdapter("ROLLBACK").execute(state, false); + } + + return new SQLQueryResult(accessedRows, queryErrors); + } + + private List getErrors() throws SQLException { + SQLancerResultSet resultSet = new SQLQueryAdapter("SHOW WARNINGS").executeAndGet(state, false); + List queryErrors = new ArrayList<>(); + if (resultSet != null) { + while (resultSet.next()) { + SQLQueryError queryError = new SQLQueryError(); + queryError.setLevel(resultSet.getString("Level").equalsIgnoreCase("ERROR") + ? SQLQueryError.ErrorLevel.ERROR : SQLQueryError.ErrorLevel.WARNING); + queryError.setCode(resultSet.getInt("Code")); + queryError.setMessage(resultSet.getString("Message")); + queryErrors.add(queryError); + } + resultSet.close(); + } + Collections.sort(queryErrors); + return queryErrors; + } + + @Override + public void addAuxiliaryColumns(AbstractRelationalTable table) throws SQLException { + String tableName = table.getName(); + + String addColumnRowID = String.format("ALTER TABLE %s ADD %s TEXT", tableName, COLUMN_ROWID); + new SQLQueryAdapter(addColumnRowID).execute(state, false); + state.getState().getLocalState().log(addColumnRowID); + + String addColumnUpdated = String.format("ALTER TABLE %s ADD %s INT DEFAULT 0", tableName, COLUMN_UPDATED); + new SQLQueryAdapter(addColumnUpdated).execute(state, false); + state.getState().getLocalState().log(addColumnUpdated); + + String updateRowsWithUniqueID = String.format("UPDATE %s SET %s = UUID()", tableName, COLUMN_ROWID); + new SQLQueryAdapter(updateRowsWithUniqueID).execute(state, false); + state.getState().getLocalState().log(updateRowsWithUniqueID); + } + + public static class MySQLErrorCodeStrategy implements ErrorCodeStrategy { + @Override + public Set getUpdateSpecificErrorCodes() { + // 1048, Column 'c0' cannot be null + // 1062, Duplicate entry '2' for key 't1.i0 + // 3105, The value specified for generated column 'c1' in table 't1' is not allowed + return Set.of(1048, 1062, 3105); + } + + @Override + public Set getDeleteSpecificErrorCodes() { + // 1451, Cannot delete or update a parent row: a foreign key constraint fails + return Set.of(1451); + } + } +} diff --git a/src/sqlancer/mysql/oracle/MySQLDQPOracle.java b/src/sqlancer/mysql/oracle/MySQLDQPOracle.java new file mode 100644 index 000000000..414ffb156 --- /dev/null +++ b/src/sqlancer/mysql/oracle/MySQLDQPOracle.java @@ -0,0 +1,104 @@ +package sqlancer.mysql.oracle; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.ComparatorHelper; +import sqlancer.Randomly; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.mysql.MySQLErrors; +import sqlancer.mysql.MySQLGlobalState; +import sqlancer.mysql.MySQLSchema.MySQLTables; +import sqlancer.mysql.MySQLVisitor; +import sqlancer.mysql.ast.MySQLColumnReference; +import sqlancer.mysql.ast.MySQLExpression; +import sqlancer.mysql.ast.MySQLJoin; +import sqlancer.mysql.ast.MySQLSelect; +import sqlancer.mysql.ast.MySQLTableReference; +import sqlancer.mysql.ast.MySQLText; +import sqlancer.mysql.gen.MySQLExpressionGenerator; +import sqlancer.mysql.gen.MySQLHintGenerator; +import sqlancer.mysql.gen.MySQLSetGenerator; + +public class MySQLDQPOracle implements TestOracle { + private final MySQLGlobalState state; + private MySQLExpressionGenerator gen; + private MySQLSelect select; + private final ExpectedErrors errors = new ExpectedErrors(); + + public MySQLDQPOracle(MySQLGlobalState globalState) { + state = globalState; + MySQLErrors.addExpressionErrors(errors); + } + + @Override + public void check() throws Exception { + // Randomly generate a query + MySQLTables tables = state.getSchema().getRandomTableNonEmptyTables(); + gen = new MySQLExpressionGenerator(state).setColumns(tables.getColumns()); + List fetchColumns = new ArrayList<>(); + fetchColumns.addAll(Randomly.nonEmptySubset(tables.getColumns()).stream() + .map(c -> new MySQLColumnReference(c, null)).collect(Collectors.toList())); + + select = new MySQLSelect(); + select.setFetchColumns(fetchColumns); + + select.setSelectType(Randomly.fromOptions(MySQLSelect.SelectType.values())); + if (Randomly.getBoolean()) { + select.setWhereClause(gen.generateExpression()); + } + if (Randomly.getBoolean()) { + select.setGroupByExpressions(fetchColumns); + if (Randomly.getBoolean()) { + select.setHavingClause(gen.generateExpression()); + } + } + + // Set the join. + List joinExpressions = MySQLJoin.getRandomJoinClauses(tables.getTables(), state); + select.setJoinList(joinExpressions.stream().map(j -> (MySQLExpression) j).collect(Collectors.toList())); + + // Set the from clause from the tables that are not used in the join. + List tableList = tables.getTables().stream().map(t -> new MySQLTableReference(t)) + .collect(Collectors.toList()); + select.setFromList(tableList); + + // Get the result of the first query + String originalQueryString = MySQLVisitor.asString(select); + List originalResult = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, + state); + + // Check hints + List hintList = MySQLHintGenerator.generateAllHints(select, tables.getTables()); + for (MySQLText hint : hintList) { + select.setHint(hint); + String queryString = MySQLVisitor.asString(select); + List result = ComparatorHelper.getResultSetFirstColumnAsString(queryString, errors, state); + ComparatorHelper.assumeResultSetsAreEqual(originalResult, result, originalQueryString, List.of(queryString), + state); + } + + // Check optimizer variables + List optimizationList = MySQLSetGenerator.getAllOptimizer(state); + for (SQLQueryAdapter optimization : optimizationList) { + optimization.execute(state); + List result = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); + try { + ComparatorHelper.assumeResultSetsAreEqual(originalResult, result, originalQueryString, + List.of(originalQueryString), state); + } catch (AssertionError e) { + String assertionMessage = String.format( + "The size of the result sets mismatch (%d and %d)!" + System.lineSeparator() + + "First query: \"%s\", whose cardinality is: %d" + System.lineSeparator() + + "Second query:\"%s\", whose cardinality is: %d", + originalResult.size(), result.size(), originalQueryString, originalResult.size(), + String.join(";", originalQueryString), result.size()); + assertionMessage += System.lineSeparator() + "The setting: " + optimization.getQueryString(); + throw new AssertionError(assertionMessage); + } + } + } +} diff --git a/src/sqlancer/mysql/oracle/MySQLFuzzer.java b/src/sqlancer/mysql/oracle/MySQLFuzzer.java new file mode 100644 index 000000000..2e361c0b8 --- /dev/null +++ b/src/sqlancer/mysql/oracle/MySQLFuzzer.java @@ -0,0 +1,30 @@ +package sqlancer.mysql.oracle; + +import sqlancer.Randomly; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.mysql.MySQLGlobalState; +import sqlancer.mysql.MySQLVisitor; +import sqlancer.mysql.gen.MySQLRandomQuerySynthesizer; + +public class MySQLFuzzer implements TestOracle { + + private final MySQLGlobalState globalState; + + public MySQLFuzzer(MySQLGlobalState globalState) { + this.globalState = globalState; + } + + @Override + public void check() throws Exception { + String s = MySQLVisitor.asString(MySQLRandomQuerySynthesizer.generate(globalState, Randomly.smallNumber() + 1)) + + ';'; + try { + globalState.executeStatement(new SQLQueryAdapter(s)); + globalState.getManager().incrementSelectQueryCount(); + } catch (Error e) { + + } + } + +} diff --git a/src/sqlancer/mysql/oracle/MySQLPivotedQuerySynthesisOracle.java b/src/sqlancer/mysql/oracle/MySQLPivotedQuerySynthesisOracle.java index 6e50fbb33..c1fe893b6 100644 --- a/src/sqlancer/mysql/oracle/MySQLPivotedQuerySynthesisOracle.java +++ b/src/sqlancer/mysql/oracle/MySQLPivotedQuerySynthesisOracle.java @@ -38,6 +38,7 @@ public MySQLPivotedQuerySynthesisOracle(MySQLGlobalState globalState) throws SQL super(globalState); MySQLErrors.addExpressionErrors(errors); errors.add("in 'order clause'"); // e.g., Unknown column '2067708013' in 'order clause' + errors.add("in 'EXISTS subquery'"); // e.g., Unknown column '2067708013' in 'EXISTS subquery' (MySQL 8.4+) } @Override @@ -68,7 +69,7 @@ public Query getRectifiedQuery() throws SQLException { selectStatement.setModifiers(modifiers); List orderBy = new MySQLExpressionGenerator(globalState).setColumns(columns) .generateOrderBys(); - selectStatement.setOrderByExpressions(orderBy); + selectStatement.setOrderByClauses(orderBy); return new SQLQueryAdapter(MySQLVisitor.asString(selectStatement), errors); } diff --git a/src/sqlancer/mysql/oracle/MySQLQueryPartitioningBase.java b/src/sqlancer/mysql/oracle/MySQLQueryPartitioningBase.java deleted file mode 100644 index 8946136fe..000000000 --- a/src/sqlancer/mysql/oracle/MySQLQueryPartitioningBase.java +++ /dev/null @@ -1,61 +0,0 @@ -package sqlancer.mysql.oracle; - -import java.sql.SQLException; -import java.util.Arrays; -import java.util.List; -import java.util.stream.Collectors; - -import sqlancer.common.gen.ExpressionGenerator; -import sqlancer.common.oracle.TernaryLogicPartitioningOracleBase; -import sqlancer.common.oracle.TestOracle; -import sqlancer.mysql.MySQLErrors; -import sqlancer.mysql.MySQLGlobalState; -import sqlancer.mysql.MySQLSchema; -import sqlancer.mysql.MySQLSchema.MySQLTable; -import sqlancer.mysql.MySQLSchema.MySQLTables; -import sqlancer.mysql.ast.MySQLColumnReference; -import sqlancer.mysql.ast.MySQLExpression; -import sqlancer.mysql.ast.MySQLSelect; -import sqlancer.mysql.ast.MySQLTableReference; -import sqlancer.mysql.gen.MySQLExpressionGenerator; - -public abstract class MySQLQueryPartitioningBase - extends TernaryLogicPartitioningOracleBase implements TestOracle { - - MySQLSchema s; - MySQLTables targetTables; - MySQLExpressionGenerator gen; - MySQLSelect select; - - public MySQLQueryPartitioningBase(MySQLGlobalState state) { - super(state); - MySQLErrors.addExpressionErrors(errors); - } - - @Override - public void check() throws SQLException { - s = state.getSchema(); - targetTables = s.getRandomTableNonEmptyTables(); - gen = new MySQLExpressionGenerator(state).setColumns(targetTables.getColumns()); - initializeTernaryPredicateVariants(); - select = new MySQLSelect(); - select.setFetchColumns(generateFetchColumns()); - List tables = targetTables.getTables(); - List tableList = tables.stream().map(t -> new MySQLTableReference(t)) - .collect(Collectors.toList()); - // List joins = MySQLJoin.getJoins(tableList, state); - select.setFromList(tableList); - select.setWhereClause(null); - // select.setJoins(joins); - } - - List generateFetchColumns() { - return Arrays.asList(MySQLColumnReference.create(targetTables.getColumns().get(0), null)); - } - - @Override - protected ExpressionGenerator getGen() { - return gen; - } - -} diff --git a/src/sqlancer/mysql/oracle/MySQLTLPWhereOracle.java b/src/sqlancer/mysql/oracle/MySQLTLPWhereOracle.java deleted file mode 100644 index 4b578a24d..000000000 --- a/src/sqlancer/mysql/oracle/MySQLTLPWhereOracle.java +++ /dev/null @@ -1,44 +0,0 @@ -package sqlancer.mysql.oracle; - -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - -import sqlancer.ComparatorHelper; -import sqlancer.Randomly; -import sqlancer.mysql.MySQLGlobalState; -import sqlancer.mysql.MySQLVisitor; - -public class MySQLTLPWhereOracle extends MySQLQueryPartitioningBase { - - public MySQLTLPWhereOracle(MySQLGlobalState state) { - super(state); - } - - @Override - public void check() throws SQLException { - super.check(); - select.setWhereClause(null); - String originalQueryString = MySQLVisitor.asString(select); - - List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); - - if (Randomly.getBoolean()) { - select.setOrderByExpressions(gen.generateOrderBys()); - } - select.setOrderByExpressions(Collections.emptyList()); - select.setWhereClause(predicate); - String firstQueryString = MySQLVisitor.asString(select); - select.setWhereClause(negatedPredicate); - String secondQueryString = MySQLVisitor.asString(select); - select.setWhereClause(isNullPredicate); - String thirdQueryString = MySQLVisitor.asString(select); - List combinedString = new ArrayList<>(); - List secondResultSet = ComparatorHelper.getCombinedResultSet(firstQueryString, secondQueryString, - thirdQueryString, combinedString, Randomly.getBoolean(), state, errors); - ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, - state); - } - -} diff --git a/src/sqlancer/oceanbase/OceanBaseErrors.java b/src/sqlancer/oceanbase/OceanBaseErrors.java index 893490c34..7a5bd8a77 100644 --- a/src/sqlancer/oceanbase/OceanBaseErrors.java +++ b/src/sqlancer/oceanbase/OceanBaseErrors.java @@ -1,5 +1,9 @@ package sqlancer.oceanbase; +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Pattern; + import sqlancer.common.query.ExpectedErrors; public final class OceanBaseErrors { @@ -7,17 +11,36 @@ public final class OceanBaseErrors { private OceanBaseErrors() { } - public static void addExpressionErrors(ExpectedErrors errors) { + public static List getExpressionErrors() { + ArrayList errors = new ArrayList<>(); + errors.add("BIGINT value is out of range"); // e.g., CAST(-('-1e500') AS SIGNED) + errors.add("value is out of range"); errors.add("is not valid for CHARACTER SET"); errors.add("The observer or zone is not the master"); errors.add("Incorrect integer value"); errors.add("Truncated incorrect DOUBLE value"); errors.add("Invalid numeric"); errors.add("Data truncated for argument"); + errors.add("Data truncated for column"); + + return errors; } - public static void addInsertErrors(ExpectedErrors errors) { + public static List getExpressionErrorsRegex() { + ArrayList errors = new ArrayList<>(); + errors.add(Pattern.compile("Unknown column '.+' in 'order clause'")); + + return errors; + } + + public static void addExpressionErrors(ExpectedErrors errors) { + errors.addAll(getExpressionErrors()); + } + + public static List getInsertErrors() { + ArrayList errors = new ArrayList<>(); + errors.add("Duplicate entry"); errors.add("cannot be null"); errors.add("doesn't have a default value"); @@ -42,5 +65,10 @@ public static void addInsertErrors(ExpectedErrors errors) { errors.add("Invalid numeric"); errors.add("Miss column"); + return errors; + } + + public static void addInsertErrors(ExpectedErrors errors) { + errors.addAll(getInsertErrors()); } } diff --git a/src/sqlancer/oceanbase/OceanBaseGlobalState.java b/src/sqlancer/oceanbase/OceanBaseGlobalState.java index b8baf08fb..fa483072b 100644 --- a/src/sqlancer/oceanbase/OceanBaseGlobalState.java +++ b/src/sqlancer/oceanbase/OceanBaseGlobalState.java @@ -4,7 +4,6 @@ import java.sql.SQLException; import sqlancer.SQLGlobalState; -import sqlancer.oceanbase.OceanBaseOptions.OceanBaseOracleFactory; public class OceanBaseGlobalState extends SQLGlobalState { diff --git a/src/sqlancer/oceanbase/OceanBaseOptions.java b/src/sqlancer/oceanbase/OceanBaseOptions.java index b19de3983..949e5158b 100644 --- a/src/sqlancer/oceanbase/OceanBaseOptions.java +++ b/src/sqlancer/oceanbase/OceanBaseOptions.java @@ -1,6 +1,5 @@ package sqlancer.oceanbase; -import java.sql.SQLException; import java.util.Arrays; import java.util.List; @@ -8,12 +7,6 @@ import com.beust.jcommander.Parameters; import sqlancer.DBMSSpecificOptions; -import sqlancer.OracleFactory; -import sqlancer.common.oracle.TestOracle; -import sqlancer.oceanbase.OceanBaseOptions.OceanBaseOracleFactory; -import sqlancer.oceanbase.oracle.OceanBaseNoRECOracle; -import sqlancer.oceanbase.oracle.OceanBasePivotedQuerySynthesisOracle; -import sqlancer.oceanbase.oracle.OceanBaseTLPWhereOracle; @Parameters(separators = "=", commandDescription = "OceanBase (default port: " + OceanBaseOptions.DEFAULT_PORT + ", default host: " + OceanBaseOptions.DEFAULT_HOST + ")") @@ -24,34 +17,6 @@ public class OceanBaseOptions implements DBMSSpecificOptions oracles = Arrays.asList(OceanBaseOracleFactory.TLP_WHERE); - public enum OceanBaseOracleFactory implements OracleFactory { - - TLP_WHERE { - @Override - public TestOracle create(OceanBaseGlobalState globalState) throws SQLException { - return new OceanBaseTLPWhereOracle(globalState); - } - }, - NoREC { - @Override - public TestOracle create(OceanBaseGlobalState globalState) throws SQLException { - return new OceanBaseNoRECOracle(globalState); - } - }, - PQS { - - @Override - public TestOracle create(OceanBaseGlobalState globalState) throws SQLException { - return new OceanBasePivotedQuerySynthesisOracle(globalState); - } - - @Override - public boolean requiresAllTablesToContainRows() { - return true; - } - } - } - @Parameter(names = { "--query-timeout" }, description = "Query timeout") public int queryTimeout = 1000000000; @Parameter(names = { "--transaction-timeout" }, description = "Transaction timeout") diff --git a/src/sqlancer/oceanbase/OceanBaseOracleFactory.java b/src/sqlancer/oceanbase/OceanBaseOracleFactory.java new file mode 100644 index 000000000..b1ab1cb5b --- /dev/null +++ b/src/sqlancer/oceanbase/OceanBaseOracleFactory.java @@ -0,0 +1,50 @@ +package sqlancer.oceanbase; + +import java.sql.SQLException; + +import sqlancer.OracleFactory; +import sqlancer.common.oracle.NoRECOracle; +import sqlancer.common.oracle.TLPWhereOracle; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.oceanbase.gen.OceanBaseExpressionGenerator; +import sqlancer.oceanbase.oracle.OceanBasePivotedQuerySynthesisOracle; + +public enum OceanBaseOracleFactory implements OracleFactory { + + TLP_WHERE { + @Override + public TestOracle create(OceanBaseGlobalState globalState) throws SQLException { + OceanBaseExpressionGenerator gen = new OceanBaseExpressionGenerator(globalState); + ExpectedErrors expectedErrors = ExpectedErrors.newErrors().with(OceanBaseErrors.getExpressionErrors()) + .withRegex(OceanBaseErrors.getExpressionErrorsRegex()).build(); + + return new TLPWhereOracle<>(globalState, gen, expectedErrors); + } + }, + NoREC { + @Override + public TestOracle create(OceanBaseGlobalState globalState) throws SQLException { + OceanBaseExpressionGenerator gen = new OceanBaseExpressionGenerator(globalState); + ExpectedErrors errors = ExpectedErrors.newErrors().with(OceanBaseErrors.getExpressionErrors()) + .withRegex(OceanBaseErrors.getExpressionErrorsRegex()) + .with("canceling statement due to statement timeout").with("unmatched parentheses") + .with("nothing to repeat at offset").with("missing )").with("missing terminating ]") + .with("range out of order in character class").with("unrecognized character after ") + .with("Got error '(*VERB) not recognized or malformed").with("must be followed by") + .with("malformed number or name after").with("digit expected after").build(); + return new NoRECOracle<>(globalState, gen, errors); + } + }, + PQS { + @Override + public TestOracle create(OceanBaseGlobalState globalState) throws SQLException { + return new OceanBasePivotedQuerySynthesisOracle(globalState); + } + + @Override + public boolean requiresAllTablesToContainRows() { + return true; + } + } +} diff --git a/src/sqlancer/oceanbase/OceanBaseSchema.java b/src/sqlancer/oceanbase/OceanBaseSchema.java index 7de2457ee..7b5b5954f 100644 --- a/src/sqlancer/oceanbase/OceanBaseSchema.java +++ b/src/sqlancer/oceanbase/OceanBaseSchema.java @@ -81,6 +81,7 @@ public int getPrecision() { return precision; } + @Override public boolean isPrimaryKey() { return isPrimaryKey; } @@ -195,10 +196,6 @@ public OceanBaseTable(String tableName, List columns, List c.isPrimaryKey()); - } - } public static final class OceanBaseIndex extends TableIndex { diff --git a/src/sqlancer/oceanbase/OceanBaseToStringVisitor.java b/src/sqlancer/oceanbase/OceanBaseToStringVisitor.java index 56fef98f1..ba306ff90 100644 --- a/src/sqlancer/oceanbase/OceanBaseToStringVisitor.java +++ b/src/sqlancer/oceanbase/OceanBaseToStringVisitor.java @@ -56,7 +56,7 @@ public void visit(OceanBaseSelect s) { throw new AssertionError(); } sb.append(s.getModifiers().stream().collect(Collectors.joining(" "))); - if (s.getModifiers().size() > 0) { + if (!s.getModifiers().isEmpty()) { sb.append(" "); } if (s.getFetchColumns() == null) { @@ -85,7 +85,7 @@ public void visit(OceanBaseSelect s) { sb.append(" WHERE "); visit(whereClause); } - if (s.getGroupByExpressions() != null && s.getGroupByExpressions().size() > 0) { + if (s.getGroupByExpressions() != null && !s.getGroupByExpressions().isEmpty()) { sb.append(" "); sb.append("GROUP BY "); List groupBys = s.getGroupByExpressions(); @@ -100,14 +100,14 @@ public void visit(OceanBaseSelect s) { sb.append(" HAVING "); visit(s.getHavingClause()); } - if (!s.getOrderByExpressions().isEmpty()) { + if (!s.getOrderByClauses().isEmpty()) { sb.append(" ORDER BY "); - List orderBys = s.getOrderByExpressions(); + List orderBys = s.getOrderByClauses(); for (int i = 0; i < orderBys.size(); i++) { if (i != 0) { sb.append(", "); } - visit(s.getOrderByExpressions().get(i)); + visit(s.getOrderByClauses().get(i)); } } if (s.getLimitClause() != null) { @@ -260,7 +260,7 @@ public void visit(OceanBaseStringExpression op) { sb.append(op.getStr()); } else { String str = op.getStr(); - if (str.length() > 0) { + if (!str.isEmpty()) { sb.append(r.getInteger(0, 100000)); } else { sb.append(r.getInteger(0, 1000000)); diff --git a/src/sqlancer/oceanbase/ast/OceanBaseConstant.java b/src/sqlancer/oceanbase/ast/OceanBaseConstant.java index 91a481fc6..84dadb9e4 100644 --- a/src/sqlancer/oceanbase/ast/OceanBaseConstant.java +++ b/src/sqlancer/oceanbase/ast/OceanBaseConstant.java @@ -194,7 +194,7 @@ public boolean isNull() { @Override public boolean isEmpty() { // "" " " - if (value.length() == 0) { + if (value.isEmpty()) { return true; } else { for (int i = 0; i < value.length(); i++) { diff --git a/src/sqlancer/oceanbase/ast/OceanBaseExpression.java b/src/sqlancer/oceanbase/ast/OceanBaseExpression.java index 195a05967..b721d711f 100644 --- a/src/sqlancer/oceanbase/ast/OceanBaseExpression.java +++ b/src/sqlancer/oceanbase/ast/OceanBaseExpression.java @@ -1,6 +1,9 @@ package sqlancer.oceanbase.ast; -public interface OceanBaseExpression { +import sqlancer.common.ast.newast.Expression; +import sqlancer.oceanbase.OceanBaseSchema.OceanBaseColumn; + +public interface OceanBaseExpression extends Expression { default OceanBaseConstant getExpectedValue() { throw new AssertionError("PQS not supported for this operator"); diff --git a/src/sqlancer/oceanbase/ast/OceanBaseJoin.java b/src/sqlancer/oceanbase/ast/OceanBaseJoin.java index 9e8271bbf..4855c0808 100644 --- a/src/sqlancer/oceanbase/ast/OceanBaseJoin.java +++ b/src/sqlancer/oceanbase/ast/OceanBaseJoin.java @@ -1,10 +1,17 @@ package sqlancer.oceanbase.ast; -public class OceanBaseJoin implements OceanBaseExpression { +import sqlancer.common.ast.newast.Join; +import sqlancer.oceanbase.OceanBaseSchema.OceanBaseColumn; +import sqlancer.oceanbase.OceanBaseSchema.OceanBaseTable; + +public class OceanBaseJoin implements OceanBaseExpression, Join { @Override public OceanBaseConstant getExpectedValue() { throw new UnsupportedOperationException(); } + @Override + public void setOnClause(OceanBaseExpression onClause) { + } } diff --git a/src/sqlancer/oceanbase/ast/OceanBaseSelect.java b/src/sqlancer/oceanbase/ast/OceanBaseSelect.java index a87ff5c27..c4bbc1b87 100644 --- a/src/sqlancer/oceanbase/ast/OceanBaseSelect.java +++ b/src/sqlancer/oceanbase/ast/OceanBaseSelect.java @@ -5,8 +5,13 @@ import java.util.List; import sqlancer.common.ast.SelectBase; +import sqlancer.common.ast.newast.Select; +import sqlancer.oceanbase.OceanBaseSchema.OceanBaseColumn; +import sqlancer.oceanbase.OceanBaseSchema.OceanBaseTable; +import sqlancer.oceanbase.OceanBaseVisitor; -public class OceanBaseSelect extends SelectBase implements OceanBaseExpression { +public class OceanBaseSelect extends SelectBase + implements OceanBaseExpression, Select { private SelectType fromOptions = SelectType.ALL; private List modifiers = Collections.emptyList(); @@ -29,10 +34,12 @@ public void setFromOptions(SelectType fromOptions) { this.fromOptions = fromOptions; } + @Override public void setGroupByClause(List groupBys) { this.groupBys = groupBys; } + @Override public List getGroupByClause() { return this.groupBys; } @@ -58,4 +65,18 @@ public OceanBaseStringExpression getHint() { return hint; } + @Override + public void setJoinClauses(List joinStatements) { + } + + @Override + public List getJoinClauses() { + return List.of(); + } + + @Override + public String asString() { + return OceanBaseVisitor.asString(this); + } + } diff --git a/src/sqlancer/oceanbase/gen/OceanBaseAlterTable.java b/src/sqlancer/oceanbase/gen/OceanBaseAlterTable.java index e67256448..7e1a69dfc 100644 --- a/src/sqlancer/oceanbase/gen/OceanBaseAlterTable.java +++ b/src/sqlancer/oceanbase/gen/OceanBaseAlterTable.java @@ -56,7 +56,7 @@ private SQLQueryAdapter create() { case COMPRESSION: sb.append("COMPRESSION "); sb.append("'"); - sb.append(Randomly.fromOptions("ZLIB_1.0", "LZ4_1.0", "NONE")); + sb.append(Randomly.fromOptions("LZ4_1.0", "NONE")); sb.append("'"); break; default: diff --git a/src/sqlancer/oceanbase/gen/OceanBaseDeleteGenerator.java b/src/sqlancer/oceanbase/gen/OceanBaseDeleteGenerator.java index f2a6b2731..ea1cb36e8 100644 --- a/src/sqlancer/oceanbase/gen/OceanBaseDeleteGenerator.java +++ b/src/sqlancer/oceanbase/gen/OceanBaseDeleteGenerator.java @@ -44,7 +44,7 @@ private SQLQueryAdapter generate() { errors.addAll(Arrays.asList("doesn't have this option", "Truncated incorrect DOUBLE value", "Truncated incorrect INTEGER value", "Truncated incorrect DECIMAL value", "Data truncated for functional index", "Incorrect value", "Out of range value for column", - "Data truncation: %s value is out of range in '%s'")); + "Data truncation:")); return new SQLQueryAdapter(sb.toString(), errors); } diff --git a/src/sqlancer/oceanbase/gen/OceanBaseExpressionGenerator.java b/src/sqlancer/oceanbase/gen/OceanBaseExpressionGenerator.java index f4f573698..42b144c93 100644 --- a/src/sqlancer/oceanbase/gen/OceanBaseExpressionGenerator.java +++ b/src/sqlancer/oceanbase/gen/OceanBaseExpressionGenerator.java @@ -3,13 +3,21 @@ import java.sql.Connection; import java.util.ArrayList; import java.util.List; +import java.util.stream.Collectors; import sqlancer.Randomly; +import sqlancer.common.gen.NoRECGenerator; +import sqlancer.common.gen.TLPWhereGenerator; import sqlancer.common.gen.UntypedExpressionGenerator; +import sqlancer.common.schema.AbstractTables; import sqlancer.oceanbase.OceanBaseGlobalState; import sqlancer.oceanbase.OceanBaseSchema; import sqlancer.oceanbase.OceanBaseSchema.OceanBaseColumn; +import sqlancer.oceanbase.OceanBaseSchema.OceanBaseDataType; import sqlancer.oceanbase.OceanBaseSchema.OceanBaseRowValue; +import sqlancer.oceanbase.OceanBaseSchema.OceanBaseTable; +import sqlancer.oceanbase.ast.OceanBaseAggregate; +import sqlancer.oceanbase.ast.OceanBaseAggregate.OceanBaseAggregateFunction; import sqlancer.oceanbase.ast.OceanBaseBinaryComparisonOperation; import sqlancer.oceanbase.ast.OceanBaseBinaryComparisonOperation.BinaryComparisonOperator; import sqlancer.oceanbase.ast.OceanBaseBinaryLogicalOperation; @@ -23,15 +31,22 @@ import sqlancer.oceanbase.ast.OceanBaseExists; import sqlancer.oceanbase.ast.OceanBaseExpression; import sqlancer.oceanbase.ast.OceanBaseInOperation; +import sqlancer.oceanbase.ast.OceanBaseJoin; +import sqlancer.oceanbase.ast.OceanBaseSelect; import sqlancer.oceanbase.ast.OceanBaseStringExpression; +import sqlancer.oceanbase.ast.OceanBaseTableReference; +import sqlancer.oceanbase.ast.OceanBaseText; import sqlancer.oceanbase.ast.OceanBaseUnaryPostfixOperation; import sqlancer.oceanbase.ast.OceanBaseUnaryPrefixOperation; import sqlancer.oceanbase.ast.OceanBaseUnaryPrefixOperation.OceanBaseUnaryPrefixOperator; -public class OceanBaseExpressionGenerator extends UntypedExpressionGenerator { +public class OceanBaseExpressionGenerator extends UntypedExpressionGenerator + implements NoRECGenerator, + TLPWhereGenerator { private OceanBaseGlobalState state; private OceanBaseRowValue rowVal; + private List tables; public OceanBaseExpressionGenerator(OceanBaseGlobalState state) { this.state = state; @@ -209,4 +224,145 @@ public OceanBaseExpression isNull(OceanBaseExpression expr) { return new OceanBaseUnaryPostfixOperation(expr, OceanBaseUnaryPostfixOperation.UnaryPostfixOperator.IS_NULL, false); } + + @Override + public OceanBaseExpressionGenerator setTablesAndColumns(AbstractTables tables) { + this.columns = tables.getColumns(); + this.tables = tables.getTables(); + + return this; + } + + @Override + public OceanBaseExpression generateBooleanExpression() { + return generateExpression(); + } + + @Override + public OceanBaseSelect generateSelect() { + return new OceanBaseSelect(); + } + + @Override + public List getRandomJoinClauses() { + return List.of(); + } + + @Override + public List getTableRefs() { + return tables.stream().map(t -> new OceanBaseTableReference(t)).collect(Collectors.toList()); + } + + @Override + public String generateOptimizedQueryString(OceanBaseSelect select, OceanBaseExpression whereCondition, + boolean shouldUseAggregate) { + if (shouldUseAggregate) { + OceanBaseExpression aggr = new OceanBaseAggregate( + new OceanBaseColumnReference(new OceanBaseColumn("*", OceanBaseDataType.INT, false, 0, false), + null), + OceanBaseAggregateFunction.COUNT); + select.setFetchColumns(List.of(aggr)); + } else { + List allColumns = columns.stream().map((c) -> new OceanBaseColumnReference(c, null)) + .collect(Collectors.toList()); + select.setFetchColumns(allColumns); + if (Randomly.getBooleanWithSmallProbability()) { + select.setOrderByClauses(generateOrderBys()); + } + } + select.setWhereClause(whereCondition); + + return select.asString(); + } + + @Override + public String generateUnoptimizedQueryString(OceanBaseSelect select, OceanBaseExpression whereCondition) { + OceanBaseExpression expr = getTrueExpr(whereCondition); + + OceanBaseText asText = new OceanBaseText(expr, " as count", false); + select.setFetchColumns(List.of(asText)); + select.setSelectType(OceanBaseSelect.SelectType.ALL); + + return "SELECT SUM(count) FROM (" + select.asString() + ") as asdf"; + } + + private enum Option { + TRUE, FALSE_NULL, NOT_NOT_TRUE, NOT_FALSE_NOT_NULL, IF, IFNULL, COALESCE + }; + + private OceanBaseExpression getTrueExpr(OceanBaseExpression randomWhereCondition) { + // we can treat "is true" as combinations of "is flase" and "not","is not true" and "not",etc. + OceanBaseUnaryPostfixOperation isTrue = new OceanBaseUnaryPostfixOperation(randomWhereCondition, + OceanBaseUnaryPostfixOperation.UnaryPostfixOperator.IS_TRUE, false); + + OceanBaseUnaryPostfixOperation isFalse = new OceanBaseUnaryPostfixOperation(randomWhereCondition, + OceanBaseUnaryPostfixOperation.UnaryPostfixOperator.IS_FALSE, false); + + OceanBaseUnaryPostfixOperation isNotFalse = new OceanBaseUnaryPostfixOperation(randomWhereCondition, + OceanBaseUnaryPostfixOperation.UnaryPostfixOperator.IS_FALSE, true); + + OceanBaseUnaryPostfixOperation isNULL = new OceanBaseUnaryPostfixOperation(randomWhereCondition, + OceanBaseUnaryPostfixOperation.UnaryPostfixOperator.IS_NULL, false); + + OceanBaseUnaryPostfixOperation isNotNULL = new OceanBaseUnaryPostfixOperation(randomWhereCondition, + OceanBaseUnaryPostfixOperation.UnaryPostfixOperator.IS_NULL, true); + + OceanBaseExpression expr = OceanBaseConstant.createNullConstant(); + Option a = Randomly.fromOptions(Option.values()); + switch (a) { + case TRUE: + expr = isTrue; + break; + case FALSE_NULL: + // not((is false) or (is null)) + expr = new OceanBaseUnaryPrefixOperation( + new OceanBaseBinaryLogicalOperation(isFalse, isNULL, + OceanBaseBinaryLogicalOperation.OceanBaseBinaryLogicalOperator.OR), + OceanBaseUnaryPrefixOperation.OceanBaseUnaryPrefixOperator.NOT); + break; + case NOT_NOT_TRUE: + // not(not(is true))) + expr = new OceanBaseUnaryPrefixOperation( + new OceanBaseUnaryPrefixOperation(isTrue, + OceanBaseUnaryPrefixOperation.OceanBaseUnaryPrefixOperator.NOT), + OceanBaseUnaryPrefixOperation.OceanBaseUnaryPrefixOperator.NOT); + break; + case NOT_FALSE_NOT_NULL: + // (is not false) and (is not null) + expr = new OceanBaseBinaryLogicalOperation(isNotFalse, isNotNULL, + OceanBaseBinaryLogicalOperation.OceanBaseBinaryLogicalOperator.AND); + break; + case IF: + // if(1, xx is true, 0) + OceanBaseExpression[] args = new OceanBaseExpression[3]; + args[0] = OceanBaseConstant.createIntConstant(1); + args[1] = isTrue; + args[2] = OceanBaseConstant.createIntConstant(0); + expr = new OceanBaseComputableFunction(OceanBaseFunction.IF, args); + break; + case IFNULL: + // ifnull(null, xx is true) + OceanBaseExpression[] ifArgs = new OceanBaseExpression[2]; + ifArgs[0] = OceanBaseConstant.createNullConstant(); + ifArgs[1] = isTrue; + expr = new OceanBaseComputableFunction(OceanBaseFunction.IFNULL, ifArgs); + break; + case COALESCE: + // coalesce(null, xx is true) + OceanBaseExpression[] coalesceArgs = new OceanBaseExpression[2]; + coalesceArgs[0] = OceanBaseConstant.createNullConstant(); + coalesceArgs[1] = isTrue; + expr = new OceanBaseComputableFunction(OceanBaseFunction.COALESCE, coalesceArgs); + break; + default: + expr = isTrue; + break; + } + return expr; + } + + @Override + public List generateFetchColumns(boolean shouldCreateDummy) { + return columns.stream().map(c -> new OceanBaseColumnReference(c, null)).collect(Collectors.toList()); + } } diff --git a/src/sqlancer/oceanbase/gen/OceanBaseUpdateGenerator.java b/src/sqlancer/oceanbase/gen/OceanBaseUpdateGenerator.java index f89434498..950317bd2 100644 --- a/src/sqlancer/oceanbase/gen/OceanBaseUpdateGenerator.java +++ b/src/sqlancer/oceanbase/gen/OceanBaseUpdateGenerator.java @@ -3,17 +3,18 @@ import java.util.List; import sqlancer.Randomly; -import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.gen.AbstractUpdateGenerator; import sqlancer.common.query.SQLQueryAdapter; import sqlancer.oceanbase.OceanBaseErrors; import sqlancer.oceanbase.OceanBaseGlobalState; import sqlancer.oceanbase.OceanBaseSchema; +import sqlancer.oceanbase.OceanBaseSchema.OceanBaseColumn; import sqlancer.oceanbase.OceanBaseVisitor; -public class OceanBaseUpdateGenerator { +public class OceanBaseUpdateGenerator extends AbstractUpdateGenerator { - private final StringBuilder sb = new StringBuilder(); private final OceanBaseGlobalState globalState; + private OceanBaseExpressionGenerator gen; private final Randomly r; public OceanBaseUpdateGenerator(OceanBaseGlobalState globalState) { @@ -26,29 +27,16 @@ public static SQLQueryAdapter update(OceanBaseGlobalState globalState) { } private SQLQueryAdapter generate() { - ExpectedErrors errors = new ExpectedErrors(); OceanBaseSchema.OceanBaseTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); - OceanBaseExpressionGenerator gen = new OceanBaseExpressionGenerator(globalState).setColumns(table.getColumns()); + List columns = table.getRandomNonEmptyColumnSubset(); + gen = new OceanBaseExpressionGenerator(globalState).setColumns(table.getColumns()); sb.append("UPDATE "); if (Randomly.getBoolean()) { sb.append(" /*+parallel(" + r.getInteger(0, 10) + ") enable_parallel_dml*/ "); } sb.append(table.getName()); sb.append(" SET "); - List columns = table.getRandomNonEmptyColumnSubset(); - for (int i = 0; i < columns.size(); i++) { - if (i != 0) { - sb.append(", "); - } - sb.append(columns.get(i).getName()); - sb.append("="); - if (Randomly.getBoolean()) { - sb.append(gen.generateConstant(columns.get(i))); - } else { - sb.append(OceanBaseVisitor.asString(gen.generateExpression())); - OceanBaseErrors.addExpressionErrors(errors); - } - } + updateColumns(columns); if (Randomly.getBoolean()) { sb.append(" WHERE "); OceanBaseErrors.addExpressionErrors(errors); @@ -60,4 +48,14 @@ private SQLQueryAdapter generate() { return new SQLQueryAdapter(sb.toString(), errors); } + + @Override + protected void updateValue(OceanBaseColumn column) { + if (Randomly.getBoolean()) { + sb.append(gen.generateConstant(column)); + } else { + sb.append(OceanBaseVisitor.asString(gen.generateExpression())); + OceanBaseErrors.addExpressionErrors(errors); + } + } } diff --git a/src/sqlancer/oceanbase/oracle/OceanBaseNoRECOracle.java b/src/sqlancer/oceanbase/oracle/OceanBaseNoRECOracle.java deleted file mode 100644 index 900ac8caa..000000000 --- a/src/sqlancer/oceanbase/oracle/OceanBaseNoRECOracle.java +++ /dev/null @@ -1,224 +0,0 @@ -package sqlancer.oceanbase.oracle; - -import java.sql.SQLException; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.stream.Collectors; - -import sqlancer.Randomly; -import sqlancer.common.oracle.NoRECBase; -import sqlancer.common.oracle.TestOracle; -import sqlancer.common.query.SQLQueryAdapter; -import sqlancer.common.query.SQLancerResultSet; -import sqlancer.oceanbase.OceanBaseGlobalState; -import sqlancer.oceanbase.OceanBaseSchema; -import sqlancer.oceanbase.OceanBaseVisitor; -import sqlancer.oceanbase.ast.OceanBaseAggregate; -import sqlancer.oceanbase.ast.OceanBaseBinaryLogicalOperation; -import sqlancer.oceanbase.ast.OceanBaseColumnName; -import sqlancer.oceanbase.ast.OceanBaseComputableFunction; -import sqlancer.oceanbase.ast.OceanBaseComputableFunction.OceanBaseFunction; -import sqlancer.oceanbase.ast.OceanBaseConstant; -import sqlancer.oceanbase.ast.OceanBaseExpression; -import sqlancer.oceanbase.ast.OceanBaseSelect; -import sqlancer.oceanbase.ast.OceanBaseTableReference; -import sqlancer.oceanbase.ast.OceanBaseText; -import sqlancer.oceanbase.ast.OceanBaseUnaryPostfixOperation; -import sqlancer.oceanbase.ast.OceanBaseUnaryPrefixOperation; -import sqlancer.oceanbase.gen.OceanBaseExpressionGenerator; - -public class OceanBaseNoRECOracle extends NoRECBase implements TestOracle { - - // SELECT COUNT(*) FROM t0 WHERE ; - // SELECT SUM(count) FROM (SELECT IS TRUE as count FROM t0); - // SELECT (SELECT COUNT(*) FROM t0 WHERE c0 IS NOT 0) = (SELECT COUNT(*) FROM - // (SELECT c0 is NOT 0 FROM t0)); - private final OceanBaseSchema s; - private String firstQueryString; - private static final int NOT_FOUND = -1; - - private enum Option { - TRUE, FALSE_NULL, NOT_NOT_TRUE, NOT_FALSE_NOT_NULL, IF, IFNULL, COALESCE - }; - - public OceanBaseNoRECOracle(OceanBaseGlobalState globalState) { - super(globalState); - this.s = globalState.getSchema(); - errors.add("is out of range"); - // regex - errors.add("unmatched parentheses"); - errors.add("nothing to repeat at offset"); - errors.add("missing )"); - errors.add("missing terminating ]"); - errors.add("range out of order in character class"); - errors.add("unrecognized character after "); - errors.add("Got error '(*VERB) not recognized or malformed"); - errors.add("must be followed by"); - errors.add("malformed number or name after"); - errors.add("digit expected after"); - } - - @Override - public void check() throws SQLException { - OceanBaseSchema.OceanBaseTable randomTable = s.getRandomTable(); - List columns = randomTable.getColumns(); - OceanBaseExpressionGenerator gen = new OceanBaseExpressionGenerator(state).setColumns(columns); - OceanBaseExpression randomWhereCondition = gen.generateExpression(); - List groupBys = Collections.emptyList(); // getRandomExpressions(columns); - List tableList = Arrays.asList(randomTable).stream() - .map(t -> new OceanBaseTableReference(t)).collect(Collectors.toList()); - int firstCount = getFirstQueryCount(tableList, randomWhereCondition, groupBys); - int secondCount = getSecondQuery(tableList, randomWhereCondition, groupBys); - if (firstCount != secondCount && firstCount != NOT_FOUND && secondCount != NOT_FOUND) { - String queryFormatString = "-- %s;\n-- count: %d"; - String firstQueryStringWithCount = String.format(queryFormatString, optimizedQueryString, firstCount); - String secondQueryStringWithCount = String.format(queryFormatString, unoptimizedQueryString, secondCount); - state.getState().getLocalState() - .log(String.format("%s\n%s", firstQueryStringWithCount, secondQueryStringWithCount)); - String assertionMessage = String.format("the counts mismatch (%d and %d)!\n%s\n%s", firstCount, secondCount, - firstQueryStringWithCount, secondQueryStringWithCount); - throw new AssertionError(assertionMessage); - } - } - - private int getSecondQuery(List tableList, OceanBaseExpression randomWhereCondition, - List groupBys) throws SQLException { - OceanBaseSelect select = new OceanBaseSelect(); - select.setGroupByClause(groupBys); - OceanBaseExpression expr = getTrueExpr(randomWhereCondition); - - OceanBaseText asText = new OceanBaseText(expr, " as count", false); - select.setFetchColumns(Arrays.asList(asText)); - select.setFromList(tableList); - select.setSelectType(OceanBaseSelect.SelectType.ALL); - int secondCount = 0; - - unoptimizedQueryString = "SELECT SUM(count) FROM (" + OceanBaseVisitor.asString(select) + ") as asdf"; - SQLQueryAdapter q = new SQLQueryAdapter(unoptimizedQueryString, errors); - SQLancerResultSet rs; - if (options.logEachSelect()) { - logger.writeCurrent(unoptimizedQueryString); - } - try { - rs = q.executeAndGet(state); - } catch (Exception e) { - throw new AssertionError(optimizedQueryString, e); - } - if (rs == null) { - return -1; - } - if (rs.next()) { - secondCount += rs.getLong(1); - } - rs.close(); - return secondCount; - } - - private int getFirstQueryCount(List tableList, OceanBaseExpression randomWhereCondition, - List groupBys) throws SQLException { - OceanBaseSelect select = new OceanBaseSelect(); - select.setGroupByClause(groupBys); - // SELECT COUNT(t1.c3) FROM t1 WHERE (- (t1.c2)); - // SELECT SUM(count) FROM (SELECT ((- (t1.c2)) IS TRUE) as count FROM t1);; - OceanBaseAggregate aggr = new OceanBaseAggregate(new OceanBaseColumnName( - new OceanBaseSchema.OceanBaseColumn("*", OceanBaseSchema.OceanBaseDataType.INT, false, 0, false)), - OceanBaseAggregate.OceanBaseAggregateFunction.COUNT); - select.setFetchColumns(Arrays.asList(aggr)); - select.setFromList(tableList); - select.setWhereClause(randomWhereCondition); - select.setSelectType(OceanBaseSelect.SelectType.ALL); - int firstCount = 0; - optimizedQueryString = OceanBaseVisitor.asString(select); - SQLQueryAdapter q = new SQLQueryAdapter(optimizedQueryString, errors); - SQLancerResultSet rs; - if (options.logEachSelect()) { - logger.writeCurrent(optimizedQueryString); - } - try { - rs = q.executeAndGet(state); - } catch (Exception e) { - throw new AssertionError(firstQueryString, e); - } - if (rs == null) { - return -1; - } - if (rs.next()) { - firstCount += rs.getLong(1); - } - rs.close(); - return firstCount; - } - - private OceanBaseExpression getTrueExpr(OceanBaseExpression randomWhereCondition) { - // we can treat "is true" as combinations of "is flase" and "not","is not true" and "not",etc. - OceanBaseUnaryPostfixOperation isTrue = new OceanBaseUnaryPostfixOperation(randomWhereCondition, - OceanBaseUnaryPostfixOperation.UnaryPostfixOperator.IS_TRUE, false); - - OceanBaseUnaryPostfixOperation isFalse = new OceanBaseUnaryPostfixOperation(randomWhereCondition, - OceanBaseUnaryPostfixOperation.UnaryPostfixOperator.IS_FALSE, false); - - OceanBaseUnaryPostfixOperation isNotFalse = new OceanBaseUnaryPostfixOperation(randomWhereCondition, - OceanBaseUnaryPostfixOperation.UnaryPostfixOperator.IS_FALSE, true); - - OceanBaseUnaryPostfixOperation isNULL = new OceanBaseUnaryPostfixOperation(randomWhereCondition, - OceanBaseUnaryPostfixOperation.UnaryPostfixOperator.IS_NULL, false); - - OceanBaseUnaryPostfixOperation isNotNULL = new OceanBaseUnaryPostfixOperation(randomWhereCondition, - OceanBaseUnaryPostfixOperation.UnaryPostfixOperator.IS_NULL, true); - - OceanBaseExpression expr = OceanBaseConstant.createNullConstant(); - Option a = Randomly.fromOptions(Option.values()); - switch (a) { - case TRUE: - expr = isTrue; - break; - case FALSE_NULL: - // not((is false) or (is null)) - expr = new OceanBaseUnaryPrefixOperation( - new OceanBaseBinaryLogicalOperation(isFalse, isNULL, - OceanBaseBinaryLogicalOperation.OceanBaseBinaryLogicalOperator.OR), - OceanBaseUnaryPrefixOperation.OceanBaseUnaryPrefixOperator.NOT); - break; - case NOT_NOT_TRUE: - // not(not(is true))) - expr = new OceanBaseUnaryPrefixOperation( - new OceanBaseUnaryPrefixOperation(isTrue, - OceanBaseUnaryPrefixOperation.OceanBaseUnaryPrefixOperator.NOT), - OceanBaseUnaryPrefixOperation.OceanBaseUnaryPrefixOperator.NOT); - break; - case NOT_FALSE_NOT_NULL: - // (is not false) and (is not null) - expr = new OceanBaseBinaryLogicalOperation(isNotFalse, isNotNULL, - OceanBaseBinaryLogicalOperation.OceanBaseBinaryLogicalOperator.AND); - break; - case IF: - // if(1, xx is true, 0) - OceanBaseExpression[] args = new OceanBaseExpression[3]; - args[0] = OceanBaseConstant.createIntConstant(1); - args[1] = isTrue; - args[2] = OceanBaseConstant.createIntConstant(0); - expr = new OceanBaseComputableFunction(OceanBaseFunction.IF, args); - break; - case IFNULL: - // ifnull(null, xx is true) - OceanBaseExpression[] ifArgs = new OceanBaseExpression[2]; - ifArgs[0] = OceanBaseConstant.createNullConstant(); - ifArgs[1] = isTrue; - expr = new OceanBaseComputableFunction(OceanBaseFunction.IFNULL, ifArgs); - break; - case COALESCE: - // coalesce(null, xx is true) - OceanBaseExpression[] coalesceArgs = new OceanBaseExpression[2]; - coalesceArgs[0] = OceanBaseConstant.createNullConstant(); - coalesceArgs[1] = isTrue; - expr = new OceanBaseComputableFunction(OceanBaseFunction.COALESCE, coalesceArgs); - break; - default: - expr = isTrue; - break; - } - return expr; - } - -} diff --git a/src/sqlancer/oceanbase/oracle/OceanBasePivotedQuerySynthesisOracle.java b/src/sqlancer/oceanbase/oracle/OceanBasePivotedQuerySynthesisOracle.java index 48fe1002e..8fa0a4bc0 100644 --- a/src/sqlancer/oceanbase/oracle/OceanBasePivotedQuerySynthesisOracle.java +++ b/src/sqlancer/oceanbase/oracle/OceanBasePivotedQuerySynthesisOracle.java @@ -71,7 +71,7 @@ public Query getRectifiedQuery() throws SQLException { selectStatement.setOffsetClause(offsetClause); } List orderBy = generateOrderBy(columns); - selectStatement.setOrderByExpressions(orderBy); + selectStatement.setOrderByClauses(orderBy); return new SQLQueryAdapter(OceanBaseVisitor.asString(selectStatement), errors); } diff --git a/src/sqlancer/oceanbase/oracle/OceanBaseTLPBase.java b/src/sqlancer/oceanbase/oracle/OceanBaseTLPBase.java deleted file mode 100644 index 51f06a3fc..000000000 --- a/src/sqlancer/oceanbase/oracle/OceanBaseTLPBase.java +++ /dev/null @@ -1,62 +0,0 @@ -package sqlancer.oceanbase.oracle; - -import java.sql.SQLException; -import java.util.Arrays; -import java.util.List; -import java.util.stream.Collectors; - -import sqlancer.common.gen.ExpressionGenerator; -import sqlancer.common.oracle.TernaryLogicPartitioningOracleBase; -import sqlancer.common.oracle.TestOracle; -import sqlancer.oceanbase.OceanBaseErrors; -import sqlancer.oceanbase.OceanBaseGlobalState; -import sqlancer.oceanbase.OceanBaseSchema; -import sqlancer.oceanbase.OceanBaseSchema.OceanBaseTable; -import sqlancer.oceanbase.OceanBaseSchema.OceanBaseTables; -import sqlancer.oceanbase.ast.OceanBaseColumnReference; -import sqlancer.oceanbase.ast.OceanBaseExpression; -import sqlancer.oceanbase.ast.OceanBaseSelect; -import sqlancer.oceanbase.ast.OceanBaseTableReference; -import sqlancer.oceanbase.gen.OceanBaseExpressionGenerator; -import sqlancer.oceanbase.gen.OceanBaseHintGenerator; - -public abstract class OceanBaseTLPBase - extends TernaryLogicPartitioningOracleBase implements TestOracle { - - OceanBaseSchema s; - OceanBaseTables targetTables; - OceanBaseExpressionGenerator gen; - OceanBaseSelect select; - - public OceanBaseTLPBase(OceanBaseGlobalState state) { - super(state); - OceanBaseErrors.addExpressionErrors(errors); - errors.add("value is out of range"); - } - - @Override - public void check() throws SQLException { - s = state.getSchema(); - targetTables = s.getRandomTableNonEmptyTables(); - gen = new OceanBaseExpressionGenerator(state).setColumns(targetTables.getColumns()); - initializeTernaryPredicateVariants(); - select = new OceanBaseSelect(); - select.setFetchColumns(generateFetchColumns()); - List tables = targetTables.getTables(); - OceanBaseHintGenerator.generateHints(select, tables); - List tableList = tables.stream().map(t -> new OceanBaseTableReference(t)) - .collect(Collectors.toList()); - select.setFromList(tableList); - select.setWhereClause(null); - } - - List generateFetchColumns() { - return Arrays.asList(OceanBaseColumnReference.create(targetTables.getColumns().get(0), null)); - } - - @Override - protected ExpressionGenerator getGen() { - return gen; - } - -} diff --git a/src/sqlancer/oceanbase/oracle/OceanBaseTLPWhereOracle.java b/src/sqlancer/oceanbase/oracle/OceanBaseTLPWhereOracle.java deleted file mode 100644 index 73a0c3a6f..000000000 --- a/src/sqlancer/oceanbase/oracle/OceanBaseTLPWhereOracle.java +++ /dev/null @@ -1,44 +0,0 @@ -package sqlancer.oceanbase.oracle; - -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - -import sqlancer.ComparatorHelper; -import sqlancer.Randomly; -import sqlancer.oceanbase.OceanBaseGlobalState; -import sqlancer.oceanbase.OceanBaseVisitor; - -public class OceanBaseTLPWhereOracle extends OceanBaseTLPBase { - - public OceanBaseTLPWhereOracle(OceanBaseGlobalState state) { - super(state); - } - - @Override - public void check() throws SQLException { - super.check(); - select.setWhereClause(null); - String originalQueryString = OceanBaseVisitor.asString(select); - - List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); - - if (Randomly.getBoolean()) { - select.setOrderByExpressions(gen.generateOrderBys()); - } - select.setOrderByExpressions(Collections.emptyList()); - select.setWhereClause(predicate); - String firstQueryString = OceanBaseVisitor.asString(select); - select.setWhereClause(negatedPredicate); - String secondQueryString = OceanBaseVisitor.asString(select); - select.setWhereClause(isNullPredicate); - String thirdQueryString = OceanBaseVisitor.asString(select); - List combinedString = new ArrayList<>(); - List secondResultSet = ComparatorHelper.getCombinedResultSet(firstQueryString, secondQueryString, - thirdQueryString, combinedString, Randomly.getBoolean(), state, errors); - ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, - state); - } - -} diff --git a/src/sqlancer/postgres/PostgresBugs.java b/src/sqlancer/postgres/PostgresBugs.java new file mode 100644 index 000000000..352d6c1c5 --- /dev/null +++ b/src/sqlancer/postgres/PostgresBugs.java @@ -0,0 +1,10 @@ +package sqlancer.postgres; + +// do not make the fields final to avoid warnings +public final class PostgresBugs { + public static boolean bug18643 = true; + + private PostgresBugs() { + } + +} diff --git a/src/sqlancer/postgres/PostgresExpectedValueVisitor.java b/src/sqlancer/postgres/PostgresExpectedValueVisitor.java index d75aa4a5e..8985caf0b 100644 --- a/src/sqlancer/postgres/PostgresExpectedValueVisitor.java +++ b/src/sqlancer/postgres/PostgresExpectedValueVisitor.java @@ -1,10 +1,12 @@ package sqlancer.postgres; +import sqlancer.IgnoreMeException; import sqlancer.postgres.ast.PostgresAggregate; import sqlancer.postgres.ast.PostgresBetweenOperation; import sqlancer.postgres.ast.PostgresBinaryLogicalOperation; import sqlancer.postgres.ast.PostgresCastOperation; import sqlancer.postgres.ast.PostgresCollate; +import sqlancer.postgres.ast.PostgresColumnReference; import sqlancer.postgres.ast.PostgresColumnValue; import sqlancer.postgres.ast.PostgresConstant; import sqlancer.postgres.ast.PostgresExpression; @@ -20,6 +22,8 @@ import sqlancer.postgres.ast.PostgresSelect.PostgresFromTable; import sqlancer.postgres.ast.PostgresSelect.PostgresSubquery; import sqlancer.postgres.ast.PostgresSimilarTo; +import sqlancer.postgres.ast.PostgresTableReference; +import sqlancer.postgres.ast.PostgresWindowFunction; public final class PostgresExpectedValueVisitor implements PostgresVisitor { @@ -44,7 +48,7 @@ private void print(PostgresExpression expr) { // try { // super.visit(expr); // } catch (IgnoreMeException e) { - // + // } // nrTabs--; // } @@ -75,6 +79,15 @@ public void visit(PostgresPrefixOperation op) { visit(op.getExpression()); } + @Override + public void visit(PostgresColumnReference column) { + print(column); + } + + @Override + public void visit(PostgresTableReference tb) { + } + @Override public void visit(PostgresSelect op) { visit(op.getWhereClause()); @@ -85,6 +98,11 @@ public void visit(PostgresOrderByTerm op) { } + @Override + public void visit(PostgresWindowFunction windowFunction) { + throw new IgnoreMeException(); + } + @Override public void visit(PostgresFunction f) { print(f); diff --git a/src/sqlancer/postgres/PostgresOptions.java b/src/sqlancer/postgres/PostgresOptions.java index b6c79d752..ebef13a16 100644 --- a/src/sqlancer/postgres/PostgresOptions.java +++ b/src/sqlancer/postgres/PostgresOptions.java @@ -1,7 +1,6 @@ package sqlancer.postgres; -import java.sql.SQLException; -import java.util.ArrayList; +import java.io.File; import java.util.Arrays; import java.util.List; @@ -9,21 +8,13 @@ import com.beust.jcommander.Parameters; import sqlancer.DBMSSpecificOptions; -import sqlancer.OracleFactory; -import sqlancer.common.oracle.CompositeTestOracle; -import sqlancer.common.oracle.TestOracle; -import sqlancer.postgres.PostgresOptions.PostgresOracleFactory; -import sqlancer.postgres.oracle.PostgresNoRECOracle; -import sqlancer.postgres.oracle.PostgresPivotedQuerySynthesisOracle; -import sqlancer.postgres.oracle.tlp.PostgresTLPAggregateOracle; -import sqlancer.postgres.oracle.tlp.PostgresTLPHavingOracle; -import sqlancer.postgres.oracle.tlp.PostgresTLPWhereOracle; @Parameters(separators = "=", commandDescription = "PostgreSQL (default port: " + PostgresOptions.DEFAULT_PORT + ", default host: " + PostgresOptions.DEFAULT_HOST + ")") public class PostgresOptions implements DBMSSpecificOptions { public static final String DEFAULT_HOST = "localhost"; public static final int DEFAULT_PORT = 5432; + private static Boolean defaultTestTablespaces; @Parameter(names = "--bulk-insert", description = "Specifies whether INSERT statements should be issued in bulk", arity = 1) public boolean allowBulkInsert; @@ -31,9 +22,18 @@ public class PostgresOptions implements DBMSSpecificOptions oracle = Arrays.asList(PostgresOracleFactory.QUERY_PARTITIONING); + @Parameter(names = "--connection-timeout", description = "Timeout in seconds for connecting to the server", arity = 1) + public int connectionTimeoutInSeconds; + @Parameter(names = "--test-collations", description = "Specifies whether to test different collations", arity = 1) public boolean testCollations = true; + @Parameter(names = "--test-tablespaces", description = "Specifies whether to test tablespace creation (default is OS-dependent)", arity = 1) + public boolean testTablespaces; + + @Parameter(names = "--tablespace-path", description = "Base path for tablespace directories (default is OS-dependent)", arity = 1) + public String tablespacePath = getDefaultTablespacePath(); + @Parameter(names = "--connection-url", description = "Specifies the URL for connecting to the PostgreSQL server", arity = 1) public String connectionURL = String.format("postgresql://%s:%d/test", PostgresOptions.DEFAULT_HOST, PostgresOptions.DEFAULT_PORT); @@ -41,43 +41,36 @@ public class PostgresOptions implements DBMSSpecificOptions { - NOREC { - @Override - public TestOracle create(PostgresGlobalState globalState) throws SQLException { - return new PostgresNoRECOracle(globalState); - } - }, - PQS { - @Override - public TestOracle create(PostgresGlobalState globalState) throws SQLException { - return new PostgresPivotedQuerySynthesisOracle(globalState); - } - - @Override - public boolean requiresAllTablesToContainRows() { - return true; - } - }, - HAVING { - - @Override - public TestOracle create(PostgresGlobalState globalState) throws SQLException { - return new PostgresTLPHavingOracle(globalState); - } - - }, - QUERY_PARTITIONING { - @Override - public TestOracle create(PostgresGlobalState globalState) throws SQLException { - List oracles = new ArrayList<>(); - oracles.add(new PostgresTLPWhereOracle(globalState)); - oracles.add(new PostgresTLPHavingOracle(globalState)); - oracles.add(new PostgresTLPAggregateOracle(globalState)); - return new CompositeTestOracle(oracles, globalState); - } - }; + private static boolean determineDefaultTablespaceSupport() { + String osName = System.getProperty("os.name").toLowerCase(); + if (osName.contains("linux")) { + System.out.println("[INFO] Linux detected: Enabling tablespace testing by default"); + return true; + } else if (osName.contains("mac") || osName.contains("darwin")) { + System.out.println( + "[INFO] macOS detected: Disabling tablespace testing by default due to different /tmp handling. Override with --test-tablespaces=true and ensure proper directory permissions."); + return false; + } else if (osName.contains("windows")) { + System.out.println( + "[INFO] Windows detected: Disabling tablespace testing by default due to path format differences. Override with --test-tablespaces=true and use --tablespace-path to set a valid Windows path."); + return false; + } else { + System.out.println( + "[INFO] Unknown OS detected: Disabling tablespace testing by default for safety. Override with --test-tablespaces=true if your system supports PostgreSQL tablespaces."); + return false; + } + } + public static String getDefaultTablespacePath() { + String osName = System.getProperty("os.name").toLowerCase(); + if (osName.contains("windows")) { + // On Windows, use a path in the temp directory + return new File(System.getProperty("java.io.tmpdir"), "postgresql" + File.separator + "tablespace") + .getAbsolutePath(); + } else { + // On Unix-like systems, use /tmp + return "/tmp/postgresql/tablespace"; + } } @Override @@ -85,4 +78,43 @@ public List getTestOracleFactory() { return oracle; } + public String getTablespacePath() { + if (tablespacePath == null || tablespacePath.isBlank()) { + throw new AssertionError("Tablespace path is null or empty. Please configure --tablespace-path"); + } + + File path = new File(tablespacePath); + + // Check if the directory exists or can be created + if (!path.exists() && !path.mkdirs()) { + throw new AssertionError("Cannot create tablespace directory: " + tablespacePath + + ". Please ensure the parent directory exists and you have write permissions."); + } + + // Check if it's actually a directory + if (!path.isDirectory()) { + throw new AssertionError("Tablespace path is not a directory: " + tablespacePath); + } + + // Check write permissions + if (!path.canWrite()) { + throw new AssertionError("No write permissions for tablespace directory: " + tablespacePath + + ". Please ensure you have write permissions to this directory."); + } + + return tablespacePath; + } + + public boolean isTestTablespaces() { + // If the user explicitly set the value via command line, use that + // Otherwise, use the OS-dependent default + return testTablespaces || getDefaultTablespaceSupport(); + } + + private static boolean getDefaultTablespaceSupport() { + if (defaultTestTablespaces == null) { + defaultTestTablespaces = determineDefaultTablespaceSupport(); + } + return defaultTestTablespaces; + } } diff --git a/src/sqlancer/postgres/PostgresOracleFactory.java b/src/sqlancer/postgres/PostgresOracleFactory.java new file mode 100644 index 000000000..7c00021ed --- /dev/null +++ b/src/sqlancer/postgres/PostgresOracleFactory.java @@ -0,0 +1,117 @@ +package sqlancer.postgres; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import sqlancer.OracleFactory; +import sqlancer.common.oracle.CERTOracle; +import sqlancer.common.oracle.CompositeTestOracle; +import sqlancer.common.oracle.NoRECOracle; +import sqlancer.common.oracle.TLPWhereOracle; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLancerResultSet; +import sqlancer.postgres.gen.PostgresCommon; +import sqlancer.postgres.gen.PostgresExpressionGenerator; +import sqlancer.postgres.oracle.PostgresFuzzer; +import sqlancer.postgres.oracle.PostgresPivotedQuerySynthesisOracle; +import sqlancer.postgres.oracle.tlp.PostgresTLPAggregateOracle; +import sqlancer.postgres.oracle.tlp.PostgresTLPHavingOracle; + +public enum PostgresOracleFactory implements OracleFactory { + NOREC { + @Override + public TestOracle create(PostgresGlobalState globalState) throws SQLException { + PostgresExpressionGenerator gen = new PostgresExpressionGenerator(globalState); + ExpectedErrors errors = ExpectedErrors.newErrors().with(PostgresCommon.getCommonExpressionErrors()) + .with(PostgresCommon.getCommonFetchErrors()) + .withRegex(PostgresCommon.getCommonExpressionRegexErrors()).build(); + return new NoRECOracle<>(globalState, gen, errors); + } + }, + PQS { + @Override + public TestOracle create(PostgresGlobalState globalState) throws SQLException { + return new PostgresPivotedQuerySynthesisOracle(globalState); + } + + @Override + public boolean requiresAllTablesToContainRows() { + return true; + } + }, + WHERE { + @Override + public TestOracle create(PostgresGlobalState globalState) throws SQLException { + PostgresExpressionGenerator gen = new PostgresExpressionGenerator(globalState); + ExpectedErrors expectedErrors = ExpectedErrors.newErrors().with(PostgresCommon.getCommonExpressionErrors()) + .with(PostgresCommon.getCommonFetchErrors()) + .withRegex(PostgresCommon.getCommonExpressionRegexErrors()).build(); + + return new TLPWhereOracle<>(globalState, gen, expectedErrors); + } + + }, + HAVING { + @Override + public TestOracle create(PostgresGlobalState globalState) throws SQLException { + return new PostgresTLPHavingOracle(globalState); + } + + }, + QUERY_PARTITIONING { + @Override + public TestOracle create(PostgresGlobalState globalState) throws Exception { + List> oracles = new ArrayList<>(); + oracles.add(WHERE.create(globalState)); + oracles.add(HAVING.create(globalState)); + oracles.add(new PostgresTLPAggregateOracle(globalState)); + return new CompositeTestOracle(oracles, globalState); + } + }, + CERT { + @Override + public TestOracle create(PostgresGlobalState globalState) throws SQLException { + PostgresExpressionGenerator gen = new PostgresExpressionGenerator(globalState); + ExpectedErrors errors = ExpectedErrors.newErrors().with(PostgresCommon.getCommonExpressionErrors()) + .withRegex(PostgresCommon.getCommonExpressionRegexErrors()) + .with(PostgresCommon.getCommonFetchErrors()).with(PostgresCommon.getCommonInsertUpdateErrors()) + .with(PostgresCommon.getGroupingErrors()).with(PostgresCommon.getCommonInsertUpdateErrors()) + .with(PostgresCommon.getCommonRangeExpressionErrors()).build(); + CERTOracle.CheckedFunction> rowCountParser = (rs) -> { + String content = rs.getString(1).trim(); + if (content.contains("Result") && content.contains("rows=")) { + try { + int ind = content.indexOf("rows="); + long number = Long.parseLong(content.substring(ind + 5).split(" ")[0]); + return Optional.of(number); + } catch (Exception e) { + } + } + return Optional.empty(); + }; + CERTOracle.CheckedFunction> queryPlanParser = (rs) -> { + String content = rs.getString(1).trim(); + String[] planPart = content.split("-> "); + String plan = planPart[planPart.length - 1]; + return Optional.of(plan.split(" ")[0].trim()); + }; + return new CERTOracle<>(globalState, gen, errors, rowCountParser, queryPlanParser); + } + + @Override + public boolean requiresAllTablesToContainRows() { + return true; + } + }, + FUZZER { + @Override + public TestOracle create(PostgresGlobalState globalState) throws Exception { + return new PostgresFuzzer(globalState); + } + + }; + +} diff --git a/src/sqlancer/postgres/PostgresProvider.java b/src/sqlancer/postgres/PostgresProvider.java index e9396a628..ec7978216 100644 --- a/src/sqlancer/postgres/PostgresProvider.java +++ b/src/sqlancer/postgres/PostgresProvider.java @@ -1,13 +1,21 @@ package sqlancer.postgres; +import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; import java.sql.Connection; import java.sql.DriverManager; import java.sql.SQLException; import java.sql.Statement; +import java.util.ArrayList; import java.util.Arrays; +import java.util.LinkedList; +import java.util.List; +import java.util.Queue; +import java.util.stream.Collectors; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.auto.service.AutoService; import sqlancer.AbstractAction; @@ -22,7 +30,6 @@ import sqlancer.common.query.SQLQueryAdapter; import sqlancer.common.query.SQLQueryProvider; import sqlancer.common.query.SQLancerResultSet; -import sqlancer.postgres.PostgresOptions.PostgresOracleFactory; import sqlancer.postgres.gen.PostgresAlterTableGenerator; import sqlancer.postgres.gen.PostgresAnalyzeGenerator; import sqlancer.postgres.gen.PostgresClusterGenerator; @@ -30,6 +37,7 @@ import sqlancer.postgres.gen.PostgresDeleteGenerator; import sqlancer.postgres.gen.PostgresDiscardGenerator; import sqlancer.postgres.gen.PostgresDropIndexGenerator; +import sqlancer.postgres.gen.PostgresExplainGenerator; import sqlancer.postgres.gen.PostgresIndexGenerator; import sqlancer.postgres.gen.PostgresInsertGenerator; import sqlancer.postgres.gen.PostgresNotifyGenerator; @@ -38,6 +46,7 @@ import sqlancer.postgres.gen.PostgresSetGenerator; import sqlancer.postgres.gen.PostgresStatisticsGenerator; import sqlancer.postgres.gen.PostgresTableGenerator; +import sqlancer.postgres.gen.PostgresTableSpaceGenerator; import sqlancer.postgres.gen.PostgresTransactionGenerator; import sqlancer.postgres.gen.PostgresTruncateGenerator; import sqlancer.postgres.gen.PostgresUpdateGenerator; @@ -91,6 +100,7 @@ public enum Action implements AbstractAction { }), // CREATE_STATISTICS(PostgresStatisticsGenerator::insert), // DROP_STATISTICS(PostgresStatisticsGenerator::remove), // + ALTER_STATISTICS(PostgresStatisticsGenerator::alter), // DELETE(PostgresDeleteGenerator::create), // DISCARD(PostgresDiscardGenerator::create), // DROP_INDEX(PostgresDropIndexGenerator::create), // @@ -110,14 +120,16 @@ public enum Action implements AbstractAction { RESET_ROLE((g) -> new SQLQueryAdapter("RESET ROLE")), // COMMENT_ON(PostgresCommentGenerator::generate), // RESET((g) -> new SQLQueryAdapter("RESET ALL") /* - * https://www.postgresql.org/docs/devel/sql-reset.html TODO: also + * https://www.postgresql.org/docs/13/sql-reset.html TODO: also * configuration parameter */), // NOTIFY(PostgresNotifyGenerator::createNotify), // LISTEN((g) -> PostgresNotifyGenerator.createListen()), // UNLISTEN((g) -> PostgresNotifyGenerator.createUnlisten()), // CREATE_SEQUENCE(PostgresSequenceGenerator::createSequence), // - CREATE_VIEW(PostgresViewGenerator::create); + EXPLAIN(PostgresExplainGenerator::create), // + CREATE_VIEW(PostgresViewGenerator::create), // + CREATE_TABLESPACE(PostgresTableSpaceGenerator::generate); private final SQLQueryProvider sqlQueryProvider; @@ -142,6 +154,9 @@ protected static int mapActions(PostgresGlobalState globalState, Action a) { case CREATE_STATISTICS: nrPerformed = r.getInteger(0, 5); break; + case ALTER_STATISTICS: + nrPerformed = r.getInteger(0, 2); + break; case DISCARD: case DROP_INDEX: nrPerformed = r.getInteger(0, 5); @@ -178,12 +193,18 @@ protected static int mapActions(PostgresGlobalState globalState, Action a) { case CREATE_VIEW: nrPerformed = r.getInteger(0, 2); break; + case CREATE_TABLESPACE: + nrPerformed = r.getInteger(0, 2); + break; case UPDATE: nrPerformed = r.getInteger(0, 10); break; case INSERT: nrPerformed = r.getInteger(0, globalState.getOptions().getMaxNumberInserts()); break; + case EXPLAIN: + nrPerformed = r.getInteger(0, 1); + break; default: throw new AssertionError(a); } @@ -267,12 +288,33 @@ public SQLConnection createDatabase(PostgresGlobalState globalState) throws SQLE } Connection con = DriverManager.getConnection("jdbc:" + entryURL, username, password); globalState.getState().logStatement(String.format("\\c %s;", entryDatabaseName)); - globalState.getState().logStatement("DROP DATABASE IF EXISTS " + databaseName); - createDatabaseCommand = getCreateDatabaseCommand(globalState); - globalState.getState().logStatement(createDatabaseCommand); + + String dropCommand = "DROP DATABASE"; + boolean forceDrop = Randomly.getBoolean(); + if (forceDrop) { + dropCommand += " FORCE"; + } + dropCommand += " IF EXISTS " + databaseName; + + globalState.getState().logStatement(dropCommand + ";"); try (Statement s = con.createStatement()) { - s.execute("DROP DATABASE IF EXISTS " + databaseName); + s.execute(dropCommand); + } catch (SQLException e) { + // If force fails, fall back to regular drop + if (forceDrop) { + String fallbackDrop = "DROP DATABASE IF EXISTS " + databaseName; + globalState.getState().logStatement(fallbackDrop + ";"); + try (Statement s = con.createStatement()) { + s.execute(fallbackDrop); + } + } else { + throw e; + } } + + // Create database section + createDatabaseCommand = getCreateDatabaseCommand(globalState); + globalState.getState().logStatement(createDatabaseCommand + ";"); try (Statement s = con.createStatement()) { s.execute(createDatabaseCommand); } @@ -325,18 +367,26 @@ protected void prepareTables(PostgresGlobalState globalState) throws Exception { private String getCreateDatabaseCommand(PostgresGlobalState state) { StringBuilder sb = new StringBuilder(); sb.append("CREATE DATABASE " + databaseName + " "); - if (Randomly.getBoolean() && ((PostgresOptions) state.getDbmsSpecificOptions()).testCollations) { + if (((PostgresOptions) state.getDbmsSpecificOptions()).testCollations) { if (Randomly.getBoolean()) { - sb.append("WITH ENCODING '"); - sb.append(Randomly.fromOptions("utf8")); - sb.append("' "); - } - for (String lc : Arrays.asList("LC_COLLATE", "LC_CTYPE")) { - if (!state.getCollates().isEmpty() && Randomly.getBoolean()) { - sb.append(String.format(" %s = '%s'", lc, Randomly.fromList(state.getCollates()))); + if (Randomly.getBoolean()) { + sb.append("WITH ENCODING '"); + sb.append(Randomly.fromOptions("utf8")); + sb.append("' "); } + if (Randomly.getBoolean() && !state.getCollates().isEmpty()) { + sb.append(String.format(" LOCALE = '%s' ", Randomly.fromList(state.getCollates()))); + } else { + for (String lc : Arrays.asList("LC_COLLATE", "LC_CTYPE")) { + if (!state.getCollates().isEmpty() && Randomly.getBoolean()) { + sb.append(String.format(" %s = '%s'", lc, Randomly.fromList(state.getCollates()))); + } + } + } + sb.append(" TEMPLATE template0"); } - sb.append(" TEMPLATE template0"); + } else { + sb.append("WITH ENCODING 'UTF8' TEMPLATE template0"); } return sb.toString(); } @@ -346,4 +396,75 @@ public String getDBMSName() { return "postgres"; } + @Override + public String getQueryPlan(String selectStr, PostgresGlobalState globalState) throws Exception { + String queryPlan = ""; + if (globalState.getOptions().logEachSelect()) { + globalState.getLogger().writeCurrent(selectStr); + try { + globalState.getLogger().getCurrentFileWriter().flush(); + } catch (IOException e) { + e.printStackTrace(); + } + } + SQLQueryAdapter q = new SQLQueryAdapter(PostgresExplainGenerator.explain(selectStr), null); + try (SQLancerResultSet rs = q.executeAndGet(globalState)) { + while (rs.next()) { + queryPlan += rs.getString(1); + } + } catch (SQLException | AssertionError e) { + queryPlan = ""; + } + return formatQueryPlan(queryPlan); + } + + @Override + protected double[] initializeWeightedAverageReward() { + return new double[PostgresProvider.Action.values().length]; + } + + @Override + protected void executeMutator(int index, PostgresGlobalState globalState) throws Exception { + SQLQueryAdapter queryMutateTable = PostgresProvider.Action.values()[index].getQuery(globalState); + globalState.executeStatement(queryMutateTable); + } + + @Override + protected boolean addRowsToAllTables(PostgresGlobalState globalState) throws Exception { + List tablesNoRow = globalState.getSchema().getDatabaseTables().stream() + .filter(t -> t.getNrRows(globalState) == 0).collect(Collectors.toList()); + for (PostgresSchema.PostgresTable table : tablesNoRow) { + SQLQueryAdapter queryAddRows = PostgresInsertGenerator.insertRows(globalState, table); + globalState.executeStatement(queryAddRows); + } + return true; + } + + public String formatQueryPlan(String queryPlan) throws IOException { + ObjectMapper mapper = new ObjectMapper(); + JsonNode root = mapper.readTree(queryPlan).get(0).get("Plan"); + // Extract nodes using BFS algorithm + List nodeTypes = extractNodeTypesIterative(root); + return String.join(" ", nodeTypes); + } + + // BFS algorithm for traversing the Json Query Plan + private static List extractNodeTypesIterative(JsonNode root) { + List result = new ArrayList<>(); + Queue queue = new LinkedList<>(); + queue.add(root); + while (!queue.isEmpty()) { + JsonNode node = queue.poll(); + if (node.has("Node Type")) { + result.add(node.get("Node Type").asText()); + } + if (node.has("Plans") && node.get("Plans").isArray()) { + for (JsonNode plan : node.get("Plans")) { + queue.add(plan); + } + } + } + return result; + } + } diff --git a/src/sqlancer/postgres/PostgresSchema.java b/src/sqlancer/postgres/PostgresSchema.java index c99c8648e..82937557c 100644 --- a/src/sqlancer/postgres/PostgresSchema.java +++ b/src/sqlancer/postgres/PostgresSchema.java @@ -125,6 +125,12 @@ public static PostgresDataType getColumnType(String typeString) { case "character varying": case "name": case "regclass": + case "regnamespace": + case "regrole": + case "regtype": + case "regproc": + case "regprocedure": + case "regoper": return PostgresDataType.TEXT; case "numeric": return PostgresDataType.DECIMAL; @@ -164,6 +170,7 @@ public enum TableType { private final TableType tableType; private final List statistics; private final boolean isInsertable; + private final boolean isPartitioned; public PostgresTable(String tableName, List columns, List indexes, TableType tableType, List statistics, boolean isView, boolean isInsertable) { @@ -171,6 +178,18 @@ public PostgresTable(String tableName, List columns, List columns, List indexes, + TableType tableType, List statistics, boolean isView, boolean isInsertable, + boolean isPartitioned) { + super(tableName, columns, indexes, isView); + this.statistics = statistics; + this.isInsertable = isInsertable; + this.tableType = tableType; + this.isPartitioned = isPartitioned; } public List getStatistics() { @@ -185,6 +204,10 @@ public boolean isInsertable() { return isInsertable; } + public boolean isPartitioned() { + return isPartitioned; + } + } public static final class PostgresStatisticsObject { @@ -225,22 +248,23 @@ public static PostgresSchema fromConnection(SQLConnection con, String databaseNa List databaseTables = new ArrayList<>(); try (Statement s = con.createStatement()) { try (ResultSet rs = s.executeQuery( - "SELECT table_name, table_schema, table_type, is_insertable_into FROM information_schema.tables WHERE table_schema='public' OR table_schema LIKE 'pg_temp_%' ORDER BY table_name;")) { + "SELECT t.table_name, t.table_schema, t.table_type, t.is_insertable_into, c.relkind FROM information_schema.tables t JOIN pg_class c ON c.relname = t.table_name JOIN pg_namespace n ON n.oid = c.relnamespace AND n.nspname = t.table_schema WHERE t.table_schema='public' OR t.table_schema LIKE 'pg_temp_%' ORDER BY t.table_name;")) { while (rs.next()) { String tableName = rs.getString("table_name"); String tableTypeSchema = rs.getString("table_schema"); boolean isInsertable = rs.getBoolean("is_insertable_into"); + boolean isPartitioned = "p".equals(rs.getString("relkind")); // TODO: also check insertable // TODO: insert into view? - boolean isView = tableName.startsWith("v"); // tableTypeStr.contains("VIEW") || - // tableTypeStr.contains("LOCAL TEMPORARY") && - // !isInsertable; + boolean isView = matchesViewName(tableName); // tableTypeStr.contains("VIEW") || + // tableTypeStr.contains("LOCAL TEMPORARY") && + // !isInsertable; PostgresTable.TableType tableType = getTableType(tableTypeSchema); List databaseColumns = getTableColumns(con, tableName); List indexes = getIndexes(con, tableName); List statistics = getStatistics(con); PostgresTable t = new PostgresTable(tableName, databaseColumns, indexes, tableType, statistics, - isView, isInsertable); + isView, isInsertable, isPartitioned); for (PostgresColumn c : databaseColumns) { c.setTable(t); } diff --git a/src/sqlancer/postgres/PostgresToStringVisitor.java b/src/sqlancer/postgres/PostgresToStringVisitor.java index 1e95ae745..87bd3c429 100644 --- a/src/sqlancer/postgres/PostgresToStringVisitor.java +++ b/src/sqlancer/postgres/PostgresToStringVisitor.java @@ -1,5 +1,6 @@ package sqlancer.postgres; +import java.util.List; import java.util.Optional; import sqlancer.Randomly; @@ -11,6 +12,7 @@ import sqlancer.postgres.ast.PostgresBinaryLogicalOperation; import sqlancer.postgres.ast.PostgresCastOperation; import sqlancer.postgres.ast.PostgresCollate; +import sqlancer.postgres.ast.PostgresColumnReference; import sqlancer.postgres.ast.PostgresColumnValue; import sqlancer.postgres.ast.PostgresConstant; import sqlancer.postgres.ast.PostgresExpression; @@ -28,6 +30,10 @@ import sqlancer.postgres.ast.PostgresSelect.PostgresFromTable; import sqlancer.postgres.ast.PostgresSelect.PostgresSubquery; import sqlancer.postgres.ast.PostgresSimilarTo; +import sqlancer.postgres.ast.PostgresTableReference; +import sqlancer.postgres.ast.PostgresWindowFunction; +import sqlancer.postgres.ast.PostgresWindowFunction.WindowFrame; +import sqlancer.postgres.ast.PostgresWindowFunction.WindowSpecification; public final class PostgresToStringVisitor extends ToStringVisitor implements PostgresVisitor { @@ -46,6 +52,11 @@ public String get() { return sb.toString(); } + @Override + public void visit(PostgresColumnReference column) { + sb.append(column.getColumn().getFullQualifiedName()); + } + @Override public void visit(PostgresPostfixOperation op) { sb.append("("); @@ -87,6 +98,11 @@ public void visit(PostgresSubquery subquery) { sb.append(subquery.getName()); } + @Override + public void visit(PostgresTableReference ref) { + sb.append(ref.getTable().getName()); + } + @Override public void visit(PostgresSelect s) { sb.append("SELECT "); @@ -149,7 +165,7 @@ public void visit(PostgresSelect s) { sb.append(" WHERE "); visit(s.getWhereClause()); } - if (s.getGroupByExpressions().size() > 0) { + if (!s.getGroupByExpressions().isEmpty()) { sb.append(" GROUP BY "); visit(s.getGroupByExpressions()); } @@ -158,9 +174,9 @@ public void visit(PostgresSelect s) { visit(s.getHavingClause()); } - if (!s.getOrderByExpressions().isEmpty()) { + if (!s.getOrderByClauses().isEmpty()) { sb.append(" ORDER BY "); - visit(s.getOrderByExpressions()); + visit(s.getOrderByClauses()); } if (s.getLimitClause() != null) { sb.append(" LIMIT "); @@ -174,10 +190,9 @@ public void visit(PostgresSelect s) { } @Override - public void visit(PostgresOrderByTerm op) { - visit(op.getExpr()); - sb.append(" "); - sb.append(op.getOrder()); + public void visit(PostgresOrderByTerm term) { + visit(term.getExpr()); + sb.append(term.isAscending() ? " ASC" : " DESC"); } @Override @@ -350,4 +365,38 @@ public void visit(PostgresLikeOperation op) { super.visit((BinaryOperation) op); } + @Override + @SuppressWarnings("unchecked") + public void visit(PostgresWindowFunction windowFunction) { + sb.append(windowFunction.getFunctionName()); + sb.append("("); + visit(windowFunction.getArguments()); + sb.append(") OVER ("); + + WindowSpecification spec = windowFunction.getWindowSpec(); + if (!spec.getPartitionBy().isEmpty()) { + sb.append("PARTITION BY "); + visit(spec.getPartitionBy()); + } + + if (!spec.getOrderBy().isEmpty()) { + if (!spec.getPartitionBy().isEmpty()) { + sb.append(" "); + } + sb.append("ORDER BY "); + visit((List) (List) spec.getOrderBy()); + } + + if (spec.getFrame() != null) { + sb.append(" "); + WindowFrame frame = spec.getFrame(); + sb.append(frame.getType().getSQL()); + sb.append(" BETWEEN "); + visit(frame.getStartExpr()); + sb.append(" AND "); + visit(frame.getEndExpr()); + } + + sb.append(")"); + } } diff --git a/src/sqlancer/postgres/PostgresVisitor.java b/src/sqlancer/postgres/PostgresVisitor.java index d66dd1271..d5bf71e81 100644 --- a/src/sqlancer/postgres/PostgresVisitor.java +++ b/src/sqlancer/postgres/PostgresVisitor.java @@ -9,6 +9,7 @@ import sqlancer.postgres.ast.PostgresBinaryLogicalOperation; import sqlancer.postgres.ast.PostgresCastOperation; import sqlancer.postgres.ast.PostgresCollate; +import sqlancer.postgres.ast.PostgresColumnReference; import sqlancer.postgres.ast.PostgresColumnValue; import sqlancer.postgres.ast.PostgresConstant; import sqlancer.postgres.ast.PostgresExpression; @@ -24,6 +25,8 @@ import sqlancer.postgres.ast.PostgresSelect.PostgresFromTable; import sqlancer.postgres.ast.PostgresSelect.PostgresSubquery; import sqlancer.postgres.ast.PostgresSimilarTo; +import sqlancer.postgres.ast.PostgresTableReference; +import sqlancer.postgres.ast.PostgresWindowFunction; import sqlancer.postgres.gen.PostgresExpressionGenerator; public interface PostgresVisitor { @@ -34,6 +37,10 @@ public interface PostgresVisitor { void visit(PostgresColumnValue c); + void visit(PostgresColumnReference c); + + void visit(PostgresTableReference tb); + void visit(PostgresPrefixOperation op); void visit(PostgresSelect op); @@ -66,6 +73,8 @@ public interface PostgresVisitor { void visit(PostgresLikeOperation op); + void visit(PostgresWindowFunction windowFunction); + default void visit(PostgresExpression expression) { if (expression instanceof PostgresConstant) { visit((PostgresConstant) expression); @@ -103,6 +112,12 @@ default void visit(PostgresExpression expression) { visit((PostgresSubquery) expression); } else if (expression instanceof PostgresLikeOperation) { visit((PostgresLikeOperation) expression); + } else if (expression instanceof PostgresColumnReference) { + visit((PostgresColumnReference) expression); + } else if (expression instanceof PostgresTableReference) { + visit((PostgresTableReference) expression); + } else if (expression instanceof PostgresWindowFunction) { + visit((PostgresWindowFunction) expression); } else { throw new AssertionError(expression); } diff --git a/src/sqlancer/postgres/ast/PostgresBinaryComparisonOperation.java b/src/sqlancer/postgres/ast/PostgresBinaryComparisonOperation.java index 95efe5b72..b77060dfd 100644 --- a/src/sqlancer/postgres/ast/PostgresBinaryComparisonOperation.java +++ b/src/sqlancer/postgres/ast/PostgresBinaryComparisonOperation.java @@ -1,5 +1,6 @@ package sqlancer.postgres.ast; +import sqlancer.IgnoreMeException; import sqlancer.Randomly; import sqlancer.common.ast.BinaryOperatorNode; import sqlancer.common.ast.BinaryOperatorNode.Operator; @@ -126,7 +127,7 @@ public PostgresConstant getExpectedValue() { PostgresConstant leftExpectedValue = getLeft().getExpectedValue(); PostgresConstant rightExpectedValue = getRight().getExpectedValue(); if (leftExpectedValue == null || rightExpectedValue == null) { - return null; + throw new IgnoreMeException(); } return getOp().getExpectedValue(leftExpectedValue, rightExpectedValue); } diff --git a/src/sqlancer/postgres/ast/PostgresColumnReference.java b/src/sqlancer/postgres/ast/PostgresColumnReference.java new file mode 100644 index 000000000..0d835bf89 --- /dev/null +++ b/src/sqlancer/postgres/ast/PostgresColumnReference.java @@ -0,0 +1,15 @@ +package sqlancer.postgres.ast; + +import sqlancer.postgres.PostgresSchema.PostgresColumn; + +public class PostgresColumnReference implements PostgresExpression { + private final PostgresColumn c; + + public PostgresColumnReference(PostgresColumn c) { + this.c = c; + } + + public PostgresColumn getColumn() { + return c; + } +} diff --git a/src/sqlancer/postgres/ast/PostgresExpression.java b/src/sqlancer/postgres/ast/PostgresExpression.java index 96ddfa433..433e2207c 100644 --- a/src/sqlancer/postgres/ast/PostgresExpression.java +++ b/src/sqlancer/postgres/ast/PostgresExpression.java @@ -1,8 +1,10 @@ package sqlancer.postgres.ast; +import sqlancer.common.ast.newast.Expression; +import sqlancer.postgres.PostgresSchema.PostgresColumn; import sqlancer.postgres.PostgresSchema.PostgresDataType; -public interface PostgresExpression { +public interface PostgresExpression extends Expression { default PostgresDataType getExpressionType() { return null; diff --git a/src/sqlancer/postgres/ast/PostgresFunctionWithUnknownResult.java b/src/sqlancer/postgres/ast/PostgresFunctionWithUnknownResult.java index 3357a7db5..287f46784 100644 --- a/src/sqlancer/postgres/ast/PostgresFunctionWithUnknownResult.java +++ b/src/sqlancer/postgres/ast/PostgresFunctionWithUnknownResult.java @@ -19,11 +19,11 @@ public enum PostgresFunctionWithUnknownResult { TEXT("text", PostgresDataType.TEXT, PostgresDataType.INET), INET_SAME_FAMILY("inet_same_family", PostgresDataType.BOOLEAN, PostgresDataType.INET, PostgresDataType.INET), - // https://www.postgresql.org/docs/devel/functions-admin.html#FUNCTIONS-ADMIN-SIGNAL-TABLE + // https://www.postgresql.org/docs/13/functions-admin.html#FUNCTIONS-ADMIN-SIGNAL-TABLE // PG_RELOAD_CONF("pg_reload_conf", PostgresDataType.BOOLEAN), // too much output // PG_ROTATE_LOGFILE("pg_rotate_logfile", PostgresDataType.BOOLEAN), prints warning - // https://www.postgresql.org/docs/devel/functions-info.html#FUNCTIONS-INFO-SESSION-TABLE + // https://www.postgresql.org/docs/13/functions-info.html#FUNCTIONS-INFO-SESSION-TABLE CURRENT_DATABASE("current_database", PostgresDataType.TEXT), // name // CURRENT_QUERY("current_query", PostgresDataType.TEXT), // can generate false positives CURRENT_SCHEMA("current_schema", PostgresDataType.TEXT), // name @@ -87,7 +87,7 @@ public PostgresExpression[] getArguments(PostgresDataType returnType, PostgresEx TO_HEX("to_hex", PostgresDataType.INT, PostgresDataType.TEXT), TRANSLATE("translate", PostgresDataType.TEXT, PostgresDataType.TEXT, PostgresDataType.TEXT, PostgresDataType.TEXT), // mathematical functions - // https://www.postgresql.org/docs/9.5/functions-math.html + // https://www.postgresql.org/docs/13/functions-math.html ABS("abs", PostgresDataType.REAL, PostgresDataType.REAL), CBRT("cbrt", PostgresDataType.REAL, PostgresDataType.REAL), CEILING("ceiling", PostgresDataType.REAL), // DEGREES("degrees", PostgresDataType.REAL), EXP("exp", PostgresDataType.REAL), LN("ln", PostgresDataType.REAL), @@ -98,7 +98,7 @@ public PostgresExpression[] getArguments(PostgresDataType returnType, PostgresEx FLOOR("floor", PostgresDataType.REAL), // trigonometric functions - complete - // https://www.postgresql.org/docs/12/functions-math.html#FUNCTIONS-MATH-TRIG-TABLE + // https://www.postgresql.org/docs/13/functions-math.html#FUNCTIONS-MATH-TRIG-TABLE ACOS("acos", PostgresDataType.REAL), // ACOSD("acosd", PostgresDataType.REAL), // ASIN("asin", PostgresDataType.REAL), // @@ -117,7 +117,7 @@ public PostgresExpression[] getArguments(PostgresDataType returnType, PostgresEx TAND("tand", PostgresDataType.REAL), // // hyperbolic functions - complete - // https://www.postgresql.org/docs/12/functions-math.html#FUNCTIONS-MATH-HYP-TABLE + // https://www.postgresql.org/docs/13/functions-math.html#FUNCTIONS-MATH-HYP-TABLE SINH("sinh", PostgresDataType.REAL), // COSH("cosh", PostgresDataType.REAL), // TANH("tanh", PostgresDataType.REAL), // @@ -125,12 +125,12 @@ public PostgresExpression[] getArguments(PostgresDataType returnType, PostgresEx ACOSH("acosh", PostgresDataType.REAL), // ATANH("atanh", PostgresDataType.REAL), // - // https://www.postgresql.org/docs/devel/functions-binarystring.html + // https://www.postgresql.org/docs/13/functions-binarystring.html GET_BIT("get_bit", PostgresDataType.INT, PostgresDataType.TEXT, PostgresDataType.INT), GET_BYTE("get_byte", PostgresDataType.INT, PostgresDataType.TEXT, PostgresDataType.INT), // range functions - // https://www.postgresql.org/docs/devel/functions-range.html#RANGE-FUNCTIONS-TABLE + // https://www.postgresql.org/docs/13/functions-range.html#RANGE-FUNCTIONS-TABLE RANGE_LOWER("lower", PostgresDataType.INT, PostgresDataType.RANGE), // RANGE_UPPER("upper", PostgresDataType.INT, PostgresDataType.RANGE), // RANGE_ISEMPTY("isempty", PostgresDataType.BOOLEAN, PostgresDataType.RANGE), // @@ -140,7 +140,7 @@ public PostgresExpression[] getArguments(PostgresDataType returnType, PostgresEx RANGE_UPPER_INF("upper_inf", PostgresDataType.BOOLEAN, PostgresDataType.RANGE), // RANGE_MERGE("range_merge", PostgresDataType.RANGE, PostgresDataType.RANGE, PostgresDataType.RANGE), // - // https://www.postgresql.org/docs/devel/functions-admin.html#FUNCTIONS-ADMIN-DBSIZE + // https://www.postgresql.org/docs/13/functions-admin.html#FUNCTIONS-ADMIN-DBSIZE GET_COLUMN_SIZE("get_column_size", PostgresDataType.INT, PostgresDataType.TEXT); // PG_DATABASE_SIZE("pg_database_size", PostgresDataType.INT, PostgresDataType.INT); // PG_SIZE_BYTES("pg_size_bytes", PostgresDataType.INT, PostgresDataType.TEXT); diff --git a/src/sqlancer/postgres/ast/PostgresJoin.java b/src/sqlancer/postgres/ast/PostgresJoin.java index 57b92ddf3..ef6dfab1f 100644 --- a/src/sqlancer/postgres/ast/PostgresJoin.java +++ b/src/sqlancer/postgres/ast/PostgresJoin.java @@ -1,9 +1,18 @@ package sqlancer.postgres.ast; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + import sqlancer.Randomly; +import sqlancer.common.ast.newast.Join; +import sqlancer.postgres.PostgresGlobalState; +import sqlancer.postgres.PostgresSchema.PostgresColumn; import sqlancer.postgres.PostgresSchema.PostgresDataType; +import sqlancer.postgres.PostgresSchema.PostgresTable; +import sqlancer.postgres.gen.PostgresExpressionGenerator; -public class PostgresJoin implements PostgresExpression { +public class PostgresJoin implements PostgresExpression, Join { public enum PostgresJoinType { INNER, LEFT, RIGHT, FULL, CROSS; @@ -12,22 +21,85 @@ public static PostgresJoinType getRandom() { return Randomly.fromOptions(values()); } + public static PostgresJoinType getRandomExcept(PostgresJoinType... exclude) { + PostgresJoinType[] values = Arrays.stream(values()).filter(m -> !Arrays.asList(exclude).contains(m)) + .toArray(PostgresJoinType[]::new); + return Randomly.fromOptions(values); + } + } private final PostgresExpression tableReference; - private final PostgresExpression onClause; - private final PostgresJoinType type; + private PostgresExpression onClause; + private PostgresJoinType type; + private final PostgresExpression leftTable; + private final PostgresExpression rightTable; public PostgresJoin(PostgresExpression tableReference, PostgresExpression onClause, PostgresJoinType type) { this.tableReference = tableReference; this.onClause = onClause; this.type = type; + this.leftTable = null; + this.rightTable = null; + } + + public PostgresJoin(PostgresExpression leftTable, PostgresExpression rightTable, PostgresJoinType joinType, + PostgresExpression whereCondition) { + this.leftTable = leftTable; + this.rightTable = rightTable; + this.type = joinType; + this.onClause = whereCondition; + this.tableReference = null; + } + + public static PostgresJoin createJoin(PostgresExpression left, PostgresExpression right, PostgresJoinType type, + PostgresExpression onClause) { + if (type == PostgresJoinType.CROSS) { + return new PostgresJoin(left, right, type, null); + } else { + return new PostgresJoin(left, right, type, onClause); + } + } + + public static List getJoins(List tableList, + PostgresGlobalState globalState) { + // Clone Table to prevent the original list from being manipulated + List tbl = new ArrayList<>(tableList); + List joinExpressions = new ArrayList<>(); + while (tbl.size() >= 2 && Randomly.getBoolean()) { + PostgresTableReference left = (PostgresTableReference) tbl.remove(0); + PostgresTableReference right = (PostgresTableReference) tbl.remove(0); + List columns = new ArrayList<>(); + columns.addAll(left.getTable().getColumns()); + columns.addAll(right.getTable().getColumns()); + PostgresExpressionGenerator joinGen = new PostgresExpressionGenerator(globalState).setColumns(columns); + joinExpressions.add(PostgresJoin.createJoin(left, right, PostgresJoinType.getRandom(), + joinGen.generateExpression(0, PostgresDataType.BOOLEAN))); + } + return joinExpressions; + } + + @Override + public void setOnClause(PostgresExpression clause) { + this.onClause = clause; + } + + public void setType(PostgresJoinType type) { + this.type = type; } public PostgresExpression getTableReference() { return tableReference; } + public PostgresExpression getLeftTable() { + return leftTable; + } + + public PostgresExpression getRightTable() { + return rightTable; + } + public PostgresExpression getOnClause() { return onClause; } diff --git a/src/sqlancer/postgres/ast/PostgresOrderByTerm.java b/src/sqlancer/postgres/ast/PostgresOrderByTerm.java index 20f93536e..76257215b 100644 --- a/src/sqlancer/postgres/ast/PostgresOrderByTerm.java +++ b/src/sqlancer/postgres/ast/PostgresOrderByTerm.java @@ -5,8 +5,10 @@ public class PostgresOrderByTerm implements PostgresExpression { - private final PostgresOrder order; private final PostgresExpression expr; + private final PostgresOrder order; + private final int limit; + private final boolean ties; public enum PostgresOrder { ASC, DESC; @@ -17,18 +19,45 @@ public static PostgresOrder getRandomOrder() { } public PostgresOrderByTerm(PostgresExpression expr, PostgresOrder order) { + if (expr == null) { + throw new IllegalArgumentException("Expression cannot be null"); + } this.expr = expr; this.order = order; + + if (Randomly.getBooleanWithRatherLowProbability()) { + this.limit = (int) Randomly.getPositiveOrZeroNonCachedInteger(); + this.ties = true; + } else { + this.limit = 0; + this.ties = false; + } + } - public PostgresOrder getOrder() { - return order; + // Constructor for window functions, might be removed in the future to have only one constructor + public PostgresOrderByTerm(PostgresExpression expr, boolean ascending) { + if (expr == null) { + throw new IllegalArgumentException("Expression cannot be null"); + } + this.expr = expr; + this.order = ascending ? PostgresOrder.ASC : PostgresOrder.DESC; + this.limit = 0; + this.ties = false; } public PostgresExpression getExpr() { return expr; } + public PostgresOrder getOrder() { + return order; + } + + public boolean isAscending() { + return order == PostgresOrder.ASC; + } + @Override public PostgresConstant getExpectedValue() { throw new AssertionError(this); @@ -39,4 +68,12 @@ public PostgresDataType getExpressionType() { return null; } + @Override + public String toString() { + if (ties) { + return String.format("%s %s FETCH FIRST %d WITH TIES", expr, order, limit); + } else { + return String.format("%s %s", expr, order); + } + } } diff --git a/src/sqlancer/postgres/ast/PostgresSelect.java b/src/sqlancer/postgres/ast/PostgresSelect.java index 70a42650f..1d7172e3e 100644 --- a/src/sqlancer/postgres/ast/PostgresSelect.java +++ b/src/sqlancer/postgres/ast/PostgresSelect.java @@ -1,19 +1,29 @@ package sqlancer.postgres.ast; +import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import sqlancer.Randomly; import sqlancer.common.ast.SelectBase; +import sqlancer.common.ast.newast.Select; +import sqlancer.postgres.PostgresSchema.PostgresColumn; import sqlancer.postgres.PostgresSchema.PostgresDataType; import sqlancer.postgres.PostgresSchema.PostgresTable; +import sqlancer.postgres.PostgresVisitor; +import sqlancer.postgres.ast.PostgresWindowFunction.WindowFrame; -public class PostgresSelect extends SelectBase implements PostgresExpression { +public class PostgresSelect extends SelectBase + implements PostgresExpression, Select { private SelectType selectOption = SelectType.ALL; private List joinClauses = Collections.emptyList(); private PostgresExpression distinctOnClause; private ForClause forClause; + private List windowFunctions = new ArrayList<>(); + private final Map windowDefinitions = new HashMap<>(); public enum ForClause { UPDATE("UPDATE"), NO_KEY_UPDATE("NO KEY UPDATE"), SHARE("SHARE"), KEY_SHARE("KEY SHARE"); @@ -33,6 +43,53 @@ public static ForClause getRandom() { } } + public static class WindowDefinition { + private final List partitionBy; + private final List orderBy; + private final WindowFrame frame; + + public WindowDefinition(List partitionBy, List orderBy, + WindowFrame frame) { + this.partitionBy = partitionBy; + this.orderBy = orderBy; + this.frame = frame; + } + + public List getPartitionBy() { + return partitionBy; + } + + public List getOrderBy() { + return orderBy; + } + + public WindowFrame getFrame() { + return frame; + } + } + + // Getters setters for windowfunctions + public List getWindowFunctions() { + return windowFunctions; + } + + public void setWindowFunctions(List windowFunctions) { + this.windowFunctions = windowFunctions; + } + + // Add methods for window definitions + public void addWindowDefinition(String name, WindowDefinition definition) { + windowDefinitions.put(name, definition); + } + + public WindowDefinition getWindowDefinition(String name) { + return windowDefinitions.get(name); + } + + public Map getWindowDefinitions() { + return windowDefinitions; + } + public static class PostgresFromTable implements PostgresExpression { private final PostgresTable t; private final boolean only; @@ -111,11 +168,13 @@ public PostgresDataType getExpressionType() { return null; } + @Override public void setJoinClauses(List joinStatements) { this.joinClauses = joinStatements; } + @Override public List getJoinClauses() { return joinClauses; } @@ -132,4 +191,8 @@ public ForClause getForClause() { return forClause; } + @Override + public String asString() { + return PostgresVisitor.asString(this); + } } diff --git a/src/sqlancer/postgres/ast/PostgresTableReference.java b/src/sqlancer/postgres/ast/PostgresTableReference.java new file mode 100644 index 000000000..2abf8a7d2 --- /dev/null +++ b/src/sqlancer/postgres/ast/PostgresTableReference.java @@ -0,0 +1,15 @@ +package sqlancer.postgres.ast; + +import sqlancer.postgres.PostgresSchema.PostgresTable; + +public class PostgresTableReference implements PostgresExpression { + private final PostgresTable table; + + public PostgresTableReference(PostgresTable table) { + this.table = table; + } + + public PostgresTable getTable() { + return table; + } +} diff --git a/src/sqlancer/postgres/ast/PostgresWindowFunction.java b/src/sqlancer/postgres/ast/PostgresWindowFunction.java new file mode 100644 index 000000000..15f87364f --- /dev/null +++ b/src/sqlancer/postgres/ast/PostgresWindowFunction.java @@ -0,0 +1,101 @@ +package sqlancer.postgres.ast; + +import java.util.List; + +import sqlancer.postgres.PostgresSchema.PostgresDataType; + +public class PostgresWindowFunction implements PostgresExpression { + + private final String functionName; + private final List arguments; + private final WindowSpecification windowSpec; + private final PostgresDataType returnType; + + public PostgresWindowFunction(String functionName, List arguments, + WindowSpecification windowSpec, PostgresDataType returnType) { + this.functionName = functionName; + this.arguments = arguments; + this.windowSpec = windowSpec; + this.returnType = returnType; + } + + public String getFunctionName() { + return functionName; + } + + public List getArguments() { + return arguments; + } + + public WindowSpecification getWindowSpec() { + return windowSpec; + } + + @Override + public PostgresDataType getExpressionType() { + return returnType; + } + + public static class WindowSpecification { + private final List partitionBy; + private final List orderBy; + private final WindowFrame frame; + + public WindowSpecification(List partitionBy, List orderBy, + WindowFrame frame) { + this.partitionBy = partitionBy; + this.orderBy = orderBy; + this.frame = frame; + } + + public List getPartitionBy() { + return partitionBy; + } + + public List getOrderBy() { + return orderBy; + } + + public WindowFrame getFrame() { + return frame; + } + } + + public static class WindowFrame { + public enum FrameType { + ROWS("ROWS"), RANGE("RANGE"); + + private final String sql; + + FrameType(String sql) { + this.sql = sql; + } + + public String getSQL() { + return sql; + } + } + + private final FrameType type; + private final PostgresExpression startExpr; + private final PostgresExpression endExpr; + + public WindowFrame(FrameType type, PostgresExpression startExpr, PostgresExpression endExpr) { + this.type = type; + this.startExpr = startExpr; + this.endExpr = endExpr; + } + + public FrameType getType() { + return type; + } + + public PostgresExpression getStartExpr() { + return startExpr; + } + + public PostgresExpression getEndExpr() { + return endExpr; + } + } +} diff --git a/src/sqlancer/postgres/gen/PostgresAlterTableGenerator.java b/src/sqlancer/postgres/gen/PostgresAlterTableGenerator.java index ac2964661..69b509f60 100644 --- a/src/sqlancer/postgres/gen/PostgresAlterTableGenerator.java +++ b/src/sqlancer/postgres/gen/PostgresAlterTableGenerator.java @@ -34,6 +34,7 @@ protected enum Action { ALTER_COLUMN_SET_ATTRIBUTE_OPTION, // ALTER [ COLUMN ] column SET ( attribute_option = value [, ... ] ) ALTER_COLUMN_RESET_ATTRIBUTE_OPTION, // ALTER [ COLUMN ] column RESET ( attribute_option [, ... ] ) ALTER_COLUMN_SET_STORAGE, // ALTER [ COLUMN ] column SET STORAGE { PLAIN | EXTERNAL | EXTENDED | MAIN } + ALTER_COLUMN_DROP_EXPRESSION, // ALTER [ COLUMN ] column DROP EXPRESSION [ IF EXISTS ] ADD_TABLE_CONSTRAINT, // ADD table_constraint [ NOT VALID ] ADD_TABLE_CONSTRAINT_USING_INDEX, // ADD table_constraint_using_index VALIDATE_CONSTRAINT, // VALIDATE CONSTRAINT constraint_name @@ -48,9 +49,12 @@ protected enum Action { SET_LOGGED_UNLOGGED, // NOT_OF, // OWNER_TO, // - REPLICA_IDENTITY + REPLICA_IDENTITY, // RENAME COLUMN old_name TO new_name (for views) + ALTER_VIEW_RENAME_COLUMN // RENAME COLUMN old_name TO new_name (for views) } + private static final List VIEW_ACTIONS = List.of(Action.ALTER_VIEW_RENAME_COLUMN); + public PostgresAlterTableGenerator(PostgresTable randomTable, PostgresGlobalState globalState, boolean generateOnlyKnown) { this.randomTable = randomTable; @@ -98,6 +102,20 @@ public List getActions(ExpectedErrors errors) { // make it more likely that the ALTER TABLE succeeds action = Randomly.subset(Randomly.smallNumber(), Action.values()); } + + // If this is a view, only allow view-compatible operations + if (randomTable.isView()) { + // Remove all non-view operations + action.removeIf(a -> !VIEW_ACTIONS.contains(a)); + // If no view operations remain, add a random view operation + if (action.isEmpty()) { + action.add(VIEW_ACTIONS.get(r.getInteger(0, VIEW_ACTIONS.size() - 1))); + } + } else { + // Remove view-specific actions if this is a table + action.removeIf(VIEW_ACTIONS::contains); + } + if (randomTable.getColumns().size() == 1) { action.remove(Action.ALTER_TABLE_DROP_COLUMN); } @@ -109,6 +127,9 @@ public List getActions(ExpectedErrors errors) { if (!randomTable.hasIndexes()) { action.remove(Action.ADD_TABLE_CONSTRAINT_USING_INDEX); } + if (randomTable.isPartitioned()) { + action.remove(Action.SET_LOGGED_UNLOGGED); + } if (action.isEmpty()) { throw new IgnoreMeException(); } @@ -120,14 +141,24 @@ public SQLQueryAdapter generate() { int i = 0; List action = getActions(errors); StringBuilder sb = new StringBuilder(); - sb.append("ALTER TABLE "); - if (Randomly.getBoolean()) { - sb.append(" ONLY"); - errors.add("cannot use ONLY for foreign key on partitioned table"); + + // Check if we're dealing with a view operation + boolean isViewOperation = action.contains(Action.ALTER_VIEW_RENAME_COLUMN); + + if (isViewOperation) { + sb.append("ALTER VIEW "); + } else { + sb.append("ALTER TABLE "); + if (Randomly.getBoolean()) { + sb.append(" ONLY"); + errors.add("cannot use ONLY for foreign key on partitioned table"); + } } + sb.append(" "); sb.append(randomTable.getName()); sb.append(" "); + for (Action a : action) { if (i++ != 0) { sb.append(", "); @@ -206,6 +237,9 @@ public SQLQueryAdapter generate() { sb.append("DROP NOT NULL"); errors.add("is in a primary key"); errors.add("is an identity column"); + errors.add("is in index used as replica identity"); + // PG18 update: otherwise we need to encode contraint inheritance info in PostgreColumn + errors.add("cannot drop inherited constraint"); } break; case ALTER_COLUMN_SET_STATISTICS: @@ -249,6 +283,19 @@ public SQLQueryAdapter generate() { errors.add("can only have storage"); errors.add("is an identity column"); break; + case ALTER_COLUMN_DROP_EXPRESSION: + alterColumn(randomTable, sb); + sb.append("DROP EXPRESSION"); + if (Randomly.getBoolean()) { + sb.append(" IF EXISTS"); + } + errors.add("is not a generated column"); + errors.add("is not a stored generated column"); + errors.add("cannot drop expression from inherited column"); + errors.add("cannot drop generation expression from inherited column"); + errors.add("must be applied to child tables too"); + errors.add("cannot drop expression from column"); + break; case ADD_TABLE_CONSTRAINT: sb.append("ADD "); sb.append("CONSTRAINT " + r.getAlphabeticChar() + " "); @@ -257,6 +304,7 @@ public SQLQueryAdapter generate() { errors.add("multiple primary keys for table"); errors.add("could not create unique index"); errors.add("contains null values"); + errors.add("is not marked NOT NULL"); errors.add("cannot cast type"); errors.add("unsupported PRIMARY KEY constraint with partition key definition"); errors.add("unsupported UNIQUE constraint with partition key definition"); @@ -295,6 +343,7 @@ public SQLQueryAdapter generate() { errors.add("appears twice in unique constraint"); errors.add("appears twice in primary key constraint"); errors.add("contains null values"); + errors.add("is not marked NOT NULL"); errors.add("insufficient columns in PRIMARY KEY constraint definition"); errors.add("which is part of the partition key"); break; @@ -363,6 +412,17 @@ public SQLQueryAdapter generate() { errors.add("cannot use invalid index"); } break; + case ALTER_VIEW_RENAME_COLUMN: + sb.append("RENAME COLUMN "); + PostgresColumn columnToRename = randomTable.getRandomColumn(); + sb.append(columnToRename.getName()); + sb.append(" TO "); + sb.append("new_" + columnToRename.getName() + "_" + r.getInteger(1, 1000)); + errors.add("column does not exist"); + errors.add("column name already exists"); + errors.add("cannot rename column of view"); + errors.add("permission denied"); + break; default: throw new AssertionError(a); } diff --git a/src/sqlancer/postgres/gen/PostgresCommon.java b/src/sqlancer/postgres/gen/PostgresCommon.java index a3e04d046..eeb160a56 100644 --- a/src/sqlancer/postgres/gen/PostgresCommon.java +++ b/src/sqlancer/postgres/gen/PostgresCommon.java @@ -5,6 +5,7 @@ import java.util.List; import java.util.concurrent.ThreadLocalRandom; import java.util.function.Function; +import java.util.regex.Pattern; import java.util.stream.Collectors; import sqlancer.IgnoreMeException; @@ -22,7 +23,9 @@ public final class PostgresCommon { private PostgresCommon() { } - public static void addCommonFetchErrors(ExpectedErrors errors) { + public static List getCommonFetchErrors() { + ArrayList errors = new ArrayList<>(); + errors.add("FULL JOIN is only supported with merge-joinable or hash-joinable join conditions"); errors.add("but it cannot be referenced from this part of the query"); errors.add("missing FROM-clause entry for table"); @@ -32,14 +35,33 @@ public static void addCommonFetchErrors(ExpectedErrors errors) { errors.add("non-integer constant in GROUP BY"); errors.add("must appear in the GROUP BY clause or be used in an aggregate function"); errors.add("GROUP BY position"); + + return errors; } - public static void addCommonTableErrors(ExpectedErrors errors) { + public static void addCommonFetchErrors(ExpectedErrors errors) { + errors.addAll(getCommonFetchErrors()); + } + + public static List getCommonTableErrors() { + ArrayList errors = new ArrayList<>(); + errors.add("is not commutative"); // exclude errors.add("operator requires run-time type coercion"); // exclude + errors.add("partitioned tables cannot be unlogged"); + + return errors; } - public static void addCommonExpressionErrors(ExpectedErrors errors) { + public static void addCommonTableErrors(ExpectedErrors errors) { + errors.addAll(getCommonTableErrors()); + } + + public static List getCommonExpressionErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("for encoding \"SQL_ASCII\" does not exist"); + errors.add("invalid byte sequence for encoding"); errors.add("You might need to add explicit type casts"); errors.add("invalid regular expression"); errors.add("could not determine which collation to use"); @@ -53,7 +75,6 @@ public static void addCommonExpressionErrors(ExpectedErrors errors) { errors.add("invalid hexadecimal digit"); errors.add("invalid hexadecimal data: odd number of digits"); errors.add("zero raised to a negative power is undefined"); - errors.add("cannot convert infinity to numeric"); errors.add("division by zero"); errors.add("invalid input syntax for type money"); errors.add("invalid input syntax for type"); @@ -64,14 +85,32 @@ public static void addCommonExpressionErrors(ExpectedErrors errors) { errors.add("a negative number raised to a non-integer power yields a complex result"); errors.add("could not determine polymorphic type because input has type unknown"); errors.add("character number must be positive"); - addToCharFunctionErrors(errors); - addBitStringOperationErrors(errors); - addFunctionErrors(errors); - addCommonRangeExpressionErrors(errors); - addCommonRegexExpressionErrors(errors); + errors.addAll(getToCharFunctionErrors()); + errors.addAll(getBitStringOperationErrors()); + errors.addAll(getFunctionErrors()); + errors.addAll(getCommonRangeExpressionErrors()); + errors.addAll(getCommonRegexExpressionErrors()); + + return errors; + } + + public static List getCommonExpressionRegexErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add(Pattern.compile("cannot convert infinity to \\w+")); + errors.addAll(getFunctionRegexErrors()); + + return errors; + } + + public static void addCommonExpressionErrors(ExpectedErrors errors) { + errors.addAll(getCommonExpressionErrors()); + errors.addAllRegexes(getCommonExpressionRegexErrors()); } - private static void addToCharFunctionErrors(ExpectedErrors errors) { + private static List getToCharFunctionErrors() { + ArrayList errors = new ArrayList<>(); + errors.add("multiple decimal points"); errors.add("and decimal point together"); errors.add("multiple decimal points"); @@ -83,16 +122,25 @@ private static void addToCharFunctionErrors(ExpectedErrors errors) { errors.add("cannot use \"S\" and \"PL\" together"); errors.add("cannot use \"PR\" and \"S\"/\"PL\"/\"MI\"/\"SG\" together"); errors.add("is not a number"); + errors.add("\"EEEE\" must be the last pattern used"); + + return errors; } - private static void addBitStringOperationErrors(ExpectedErrors errors) { + private static List getBitStringOperationErrors() { + ArrayList errors = new ArrayList<>(); + errors.add("cannot XOR bit strings of different sizes"); errors.add("cannot AND bit strings of different sizes"); errors.add("cannot OR bit strings of different sizes"); errors.add("must be type boolean, not type text"); + + return errors; } - private static void addFunctionErrors(ExpectedErrors errors) { + private static List getFunctionErrors() { + ArrayList errors = new ArrayList<>(); + errors.add("out of valid range"); // get_bit/get_byte errors.add("cannot take logarithm of a negative number"); errors.add("cannot take logarithm of zero"); @@ -101,26 +149,82 @@ private static void addFunctionErrors(ExpectedErrors errors) { errors.add("requested character not valid for encoding"); // chr errors.add("requested length too large"); // repeat errors.add("invalid memory alloc request size"); // repeat - errors.add("encoding conversion from UTF8 to ASCII not supported"); // to_ascii + errors.add("negative substring length not allowed"); // substr errors.add("invalid mask length"); // set_masklen + + return errors; + } + + private static List getFunctionRegexErrors() { + ArrayList errors = new ArrayList<>(); + /* + * PostgreSQL support only a few conversion variants to ASCII: LATIN1, LATIN2, LATIN9 and WINDOWS1250. So, it is + * better to skip this error at all. + */ + errors.add(Pattern.compile("encoding conversion from \\w+ to ASCII not supported")); + + /* + * In accordance with PostgreSQL code, commit 0ab1a2e, conversions to or from SQL_ASCII is meaningless. So + * disable errors on such an attempt. + */ + errors.add(Pattern.compile("encoding conversion from SQL_ASCII to \\w+ not supported")); + errors.add(Pattern.compile("encoding conversion from \\w+ to SQL_ASCII not supported")); + + return errors; } - private static void addCommonRegexExpressionErrors(ExpectedErrors errors) { + private static List getCommonRegexExpressionErrors() { + ArrayList errors = new ArrayList<>(); + errors.add("is not a valid hexadecimal digit"); + + return errors; } - public static void addCommonRangeExpressionErrors(ExpectedErrors errors) { + public static List getCommonRangeExpressionErrors() { + ArrayList errors = new ArrayList<>(); + errors.add("range lower bound must be less than or equal to range upper bound"); errors.add("result of range difference would not be contiguous"); errors.add("out of range"); errors.add("malformed range literal"); errors.add("result of range union would not be contiguous"); + + return errors; } - public static void addCommonInsertUpdateErrors(ExpectedErrors errors) { + public static void addCommonRangeExpressionErrors(ExpectedErrors errors) { + errors.addAll(getCommonRangeExpressionErrors()); + } + + public static List getCommonInsertUpdateErrors() { + ArrayList errors = new ArrayList<>(); + errors.add("value too long for type character"); + errors.add("cannot insert a non-DEFAULT value into column"); errors.add("not found in view targetlist"); + + return errors; + } + + public static void addCommonInsertUpdateErrors(ExpectedErrors errors) { + errors.addAll(getCommonInsertUpdateErrors()); + } + + public static List getGroupingErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("non-integer constant in GROUP BY"); // TODO + errors.add("must appear in the GROUP BY clause or be used in an aggregate function"); + errors.add("is not in select list"); + errors.add("aggregate functions are not allowed in GROUP BY"); + + return errors; + } + + public static void addGroupingErrors(ExpectedErrors errors) { + errors.addAll(getGroupingErrors()); } public static boolean appendDataType(PostgresDataType type, StringBuilder sb, boolean allowSerial, @@ -334,6 +438,8 @@ private static void addTableConstraint(StringBuilder sb, PostgresTable table, Po } break; case EXCLUDE: + errors.add("exclusion constraints are not supported on partitioned tables"); + errors.add("unsupported EXCLUDE constraint with partition key definition"); sb.append("EXCLUDE "); sb.append("("); // TODO [USING index_method ] @@ -354,7 +460,6 @@ private static void addTableConstraint(StringBuilder sb, PostgresTable table, Po errors.add("exclusion constraints are not supported on partitioned tables"); errors.add("The exclusion operator must be related to the index operator class for the constraint"); errors.add("could not create exclusion constraint"); - // TODO: index parameters if (Randomly.getBoolean()) { sb.append(" WHERE "); sb.append("("); @@ -409,12 +514,4 @@ private static void appendExcludeElement(StringBuilder sb, PostgresGlobalState g private static void deleteOrUpdateAction(StringBuilder sb) { sb.append(Randomly.fromOptions("NO ACTION", "RESTRICT", "CASCADE", "SET NULL", "SET DEFAULT")); } - - public static void addGroupingErrors(ExpectedErrors errors) { - errors.add("non-integer constant in GROUP BY"); // TODO - errors.add("must appear in the GROUP BY clause or be used in an aggregate function"); - errors.add("is not in select list"); - errors.add("aggregate functions are not allowed in GROUP BY"); - } - } diff --git a/src/sqlancer/postgres/gen/PostgresDiscardGenerator.java b/src/sqlancer/postgres/gen/PostgresDiscardGenerator.java index c977fa272..0c02efcd0 100644 --- a/src/sqlancer/postgres/gen/PostgresDiscardGenerator.java +++ b/src/sqlancer/postgres/gen/PostgresDiscardGenerator.java @@ -25,6 +25,7 @@ public static SQLQueryAdapter create(PostgresGlobalState globalState) { } sb.append(what); return new SQLQueryAdapter(sb.toString(), ExpectedErrors.from("cannot run inside a transaction block")) { + private static final long serialVersionUID = 1L; @Override public boolean couldAffectSchema() { diff --git a/src/sqlancer/postgres/gen/PostgresExplainGenerator.java b/src/sqlancer/postgres/gen/PostgresExplainGenerator.java new file mode 100644 index 000000000..e4359e5aa --- /dev/null +++ b/src/sqlancer/postgres/gen/PostgresExplainGenerator.java @@ -0,0 +1,82 @@ +package sqlancer.postgres.gen; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.postgres.PostgresGlobalState; +import sqlancer.postgres.PostgresSchema; +import sqlancer.postgres.PostgresSchema.PostgresDataType; +import sqlancer.postgres.PostgresSchema.PostgresTables; +import sqlancer.postgres.ast.PostgresSelect; + +public final class PostgresExplainGenerator { + + private PostgresExplainGenerator() { + + } + + public static String explain(String selectStr) { + StringBuilder sb = new StringBuilder(); + sb.append("EXPLAIN (FORMAT JSON) "); + sb.append(selectStr); + return sb.toString(); + } + + public static String explainGeneral(String selectStr) { + StringBuilder sb = new StringBuilder(); + sb.append("EXPLAIN "); + + List options = new ArrayList<>(); + boolean analyze = Randomly.getBoolean(); + boolean genericPlan = !analyze && Randomly.getBoolean(); + if (analyze) { + options.add("ANALYZE"); + } + if (genericPlan) { + options.add("GENERIC_PLAN"); + } + if (Randomly.getBoolean()) { + options.add("FORMAT " + Randomly.fromOptions("TEXT", "XML", "JSON", "YAML")); + } + if (Randomly.getBoolean()) { + options.add("VERBOSE"); + } + if (Randomly.getBoolean()) { + options.add("COSTS"); + } + if (analyze && Randomly.getBoolean()) { + options.add("BUFFERS"); + } + if (analyze && Randomly.getBoolean()) { + options.add("TIMING"); + } + if (Randomly.getBoolean()) { + options.add("SUMMARY"); + } + if (!options.isEmpty()) { + sb.append("("); + sb.append(String.join(", ", options)); + sb.append(") "); + } + + sb.append(selectStr); + return sb.toString(); + } + + public static SQLQueryAdapter create(PostgresGlobalState globalState) throws Exception { + PostgresSchema.PostgresTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + PostgresExpressionGenerator gen = new PostgresExpressionGenerator(globalState); + gen.setTablesAndColumns(new PostgresTables(Arrays.asList(table))); + PostgresSelect select = gen.generateSelect(); + select.setFromList(gen.getTableRefs()); + select.setFetchColumns(gen.generateFetchColumns(false)); + if (Randomly.getBoolean()) { + select.setWhereClause(gen.generateExpression(PostgresDataType.BOOLEAN)); + } + return new SQLQueryAdapter(explainGeneral(select.asString())); + } + +} diff --git a/src/sqlancer/postgres/gen/PostgresExpressionGenerator.java b/src/sqlancer/postgres/gen/PostgresExpressionGenerator.java index f979e97da..bad87affc 100644 --- a/src/sqlancer/postgres/gen/PostgresExpressionGenerator.java +++ b/src/sqlancer/postgres/gen/PostgresExpressionGenerator.java @@ -5,18 +5,25 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; import sqlancer.IgnoreMeException; import sqlancer.Randomly; +import sqlancer.common.gen.CERTGenerator; import sqlancer.common.gen.ExpressionGenerator; +import sqlancer.common.gen.NoRECGenerator; +import sqlancer.common.gen.TLPWhereGenerator; +import sqlancer.postgres.PostgresBugs; import sqlancer.postgres.PostgresCompoundDataType; import sqlancer.postgres.PostgresGlobalState; import sqlancer.postgres.PostgresProvider; import sqlancer.postgres.PostgresSchema.PostgresColumn; import sqlancer.postgres.PostgresSchema.PostgresDataType; import sqlancer.postgres.PostgresSchema.PostgresRowValue; +import sqlancer.postgres.PostgresSchema.PostgresTable; +import sqlancer.postgres.PostgresSchema.PostgresTables; import sqlancer.postgres.ast.PostgresAggregate; import sqlancer.postgres.ast.PostgresAggregate.PostgresAggregateFunction; import sqlancer.postgres.ast.PostgresBetweenOperation; @@ -40,18 +47,32 @@ import sqlancer.postgres.ast.PostgresFunction.PostgresFunctionWithResult; import sqlancer.postgres.ast.PostgresFunctionWithUnknownResult; import sqlancer.postgres.ast.PostgresInOperation; +import sqlancer.postgres.ast.PostgresJoin; +import sqlancer.postgres.ast.PostgresJoin.PostgresJoinType; import sqlancer.postgres.ast.PostgresLikeOperation; import sqlancer.postgres.ast.PostgresOrderByTerm; -import sqlancer.postgres.ast.PostgresOrderByTerm.PostgresOrder; import sqlancer.postgres.ast.PostgresPOSIXRegularExpression; import sqlancer.postgres.ast.PostgresPOSIXRegularExpression.POSIXRegex; import sqlancer.postgres.ast.PostgresPostfixOperation; import sqlancer.postgres.ast.PostgresPostfixOperation.PostfixOperator; +import sqlancer.postgres.ast.PostgresPostfixText; import sqlancer.postgres.ast.PostgresPrefixOperation; import sqlancer.postgres.ast.PostgresPrefixOperation.PrefixOperator; +import sqlancer.postgres.ast.PostgresSelect; +import sqlancer.postgres.ast.PostgresSelect.ForClause; +import sqlancer.postgres.ast.PostgresSelect.PostgresFromTable; +import sqlancer.postgres.ast.PostgresSelect.PostgresSubquery; +import sqlancer.postgres.ast.PostgresSelect.SelectType; import sqlancer.postgres.ast.PostgresSimilarTo; +import sqlancer.postgres.ast.PostgresTableReference; +import sqlancer.postgres.ast.PostgresWindowFunction; +import sqlancer.postgres.ast.PostgresWindowFunction.WindowFrame; +import sqlancer.postgres.ast.PostgresWindowFunction.WindowSpecification; -public class PostgresExpressionGenerator implements ExpressionGenerator { +public class PostgresExpressionGenerator implements ExpressionGenerator, + NoRECGenerator, + TLPWhereGenerator, + CERTGenerator { private final int maxDepth; @@ -59,6 +80,8 @@ public class PostgresExpressionGenerator implements ExpressionGenerator columns; + private List targetTables; + private PostgresRowValue rw; private boolean expectedResult; @@ -93,11 +116,12 @@ public PostgresExpression generateExpression(int depth) { return generateExpression(depth, PostgresDataType.getRandomType()); } - public List generateOrderBy() { + @Override + public List generateOrderBys() { List orderBys = new ArrayList<>(); for (int i = 0; i < Randomly.smallNumber(); i++) { - orderBys.add(new PostgresOrderByTerm(PostgresColumnValue.create(Randomly.fromList(columns), null), - PostgresOrder.getRandomOrder())); + PostgresExpression expr = PostgresColumnValue.create(Randomly.fromList(columns), null); + orderBys.add(expr); } return orderBys; } @@ -206,8 +230,7 @@ private PostgresExpression generateBooleanExpression(int depth) { } private PostgresDataType getMeaningfulType() { - // make it more likely that the expression does not only consist of constant - // expressions + // make it more likely that the expression does not only consist of constant expressions if (Randomly.getBooleanWithSmallProbability() || columns == null || columns.isEmpty()) { return PostgresDataType.getRandomType(); } else { @@ -400,6 +423,66 @@ private PostgresExpression generateTextExpression(int depth) { } } + public PostgresExpression generateWindowFunction(int depth, PostgresDataType returnType) { + List arguments = generateWindowFunctionArguments(depth); + List partitionBy = generatePartitionByExpressions(depth); + List orderBy = generateOrderByExpressions(depth); + WindowFrame frame = generateWindowFrame(); + + WindowSpecification windowSpec = new WindowSpecification(partitionBy, orderBy, frame); + String functionName = selectWindowFunctionName(); + + return new PostgresWindowFunction(functionName, arguments, windowSpec, returnType); + } + + private List generateWindowFunctionArguments(int depth) { + List arguments = new ArrayList<>(); + if (Randomly.getBoolean()) { + arguments.add(generateExpression(depth + 1)); + } + return arguments; + } + + private List generatePartitionByExpressions(int depth) { + List partitionBy = new ArrayList<>(); + if (Randomly.getBoolean()) { + int count = Randomly.smallNumber(); + for (int i = 0; i < count; i++) { + partitionBy.add(generateExpression(depth + 1)); + } + } + return partitionBy; + } + + private List generateOrderByExpressions(int depth) { + List orderBy = new ArrayList<>(); + if (Randomly.getBoolean()) { + int count = Randomly.smallNumber(); + for (int i = 0; i < count; i++) { + PostgresExpression expr = generateExpression(depth + 1); + // Call the second constructor in PostgresOrderByTerm, might be removed in the future to have only one + // constructor + orderBy.add(new PostgresOrderByTerm(expr, Randomly.getBoolean())); + } + } + return orderBy; + } + + private WindowFrame generateWindowFrame() { + if (Randomly.getBoolean()) { + WindowFrame.FrameType frameType = Randomly.fromOptions(WindowFrame.FrameType.values()); + PostgresExpression startExpr = generateConstant(globalState.getRandomly(), PostgresDataType.INT); + PostgresExpression endExpr = generateConstant(globalState.getRandomly(), PostgresDataType.INT); + return new WindowFrame(frameType, startExpr, endExpr); + } + return null; + } + + private String selectWindowFunctionName() { + return Randomly.fromList(Arrays.asList("row_number", "rank", "dense_rank", "percent_rank", "cume_dist", "ntile", + "lag", "lead", "first_value", "last_value", "nth_value")); + } + private PostgresExpression generateConcat(int depth) { PostgresExpression left = generateExpression(depth + 1, PostgresDataType.TEXT); PostgresExpression right = generateExpression(depth + 1); @@ -423,6 +506,7 @@ private PostgresExpression generateBitExpression(int depth) { } } + // Removed WINDOW_FUNCTION option from the integer expression generation. private enum IntExpression { UNARY_OPERATION, FUNCTION, CAST, BINARY_ARITHMETIC_EXPRESSION } @@ -586,6 +670,35 @@ public PostgresExpressionGenerator allowAggregates(boolean value) { return this; } + public static PostgresSubquery createSubquery(PostgresGlobalState globalState, String name, PostgresTables tables) { + List columns = new ArrayList<>(); + PostgresExpressionGenerator gen = new PostgresExpressionGenerator(globalState).setColumns(tables.getColumns()); + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + columns.add(gen.generateExpression(0)); + } + PostgresSelect select = new PostgresSelect(); + select.setFromList(tables.getTables().stream().map(t -> new PostgresFromTable(t, Randomly.getBoolean())) + .collect(Collectors.toList())); + select.setFetchColumns(columns); + if (Randomly.getBoolean()) { + select.setWhereClause(gen.generateExpression(0, PostgresDataType.BOOLEAN)); + } + if (Randomly.getBooleanWithRatherLowProbability()) { + select.setOrderByClauses(gen.generateOrderBys()); + } + if (Randomly.getBoolean()) { + select.setLimitClause(PostgresConstant.createIntConstant(Randomly.getPositiveOrZeroNonCachedInteger())); + if (Randomly.getBoolean()) { + select.setOffsetClause( + PostgresConstant.createIntConstant(Randomly.getPositiveOrZeroNonCachedInteger())); + } + } + if (Randomly.getBooleanWithRatherLowProbability()) { + select.setForClause(ForClause.getRandom()); + } + return new PostgresSubquery(select, name); + } + @Override public PostgresExpression generatePredicate() { return generateExpression(PostgresDataType.BOOLEAN); @@ -601,4 +714,262 @@ public PostgresExpression isNull(PostgresExpression expr) { return new PostgresPostfixOperation(expr, PostfixOperator.IS_NULL); } + @Override + public PostgresExpressionGenerator setTablesAndColumns( + sqlancer.common.schema.AbstractTables targetTables) { + this.targetTables = targetTables.getTables(); + this.columns = targetTables.getColumns(); + return this; + } + + @Override + public PostgresExpression generateBooleanExpression() { + return generateExpression(PostgresDataType.BOOLEAN); + } + + @Override + public PostgresSelect generateSelect() { + PostgresSelect select = new PostgresSelect(); + + if (Randomly.getBooleanWithRatherLowProbability()) { + List windowFunctions = generateWindowFunctions(); + select.setWindowFunctions(windowFunctions); + } + + return select; + } + + private List generateWindowFunctions() { + List windowFunctions = new ArrayList<>(); + int numWindowFunctions = Randomly.smallNumber(); + for (int i = 0; i < numWindowFunctions; i++) { + windowFunctions.add(generateWindowFunction(0, + Randomly.fromList(Arrays.asList(PostgresDataType.INT, PostgresDataType.FLOAT)))); + } + return windowFunctions; + } + + @Override + public List getRandomJoinClauses() { + List joinStatements = new ArrayList<>(); + for (int i = 1; i < targetTables.size(); i++) { + PostgresExpression joinClause = generateExpression(PostgresDataType.BOOLEAN); + PostgresTable table = Randomly.fromList(targetTables); + targetTables.remove(table); + PostgresJoinType options = PostgresJoinType.getRandom(); + PostgresJoin j = new PostgresJoin(new PostgresFromTable(table, Randomly.getBoolean()), joinClause, options); + joinStatements.add(j); + } + // JOIN subqueries + for (int i = 0; i < Randomly.smallNumber(); i++) { + PostgresTables subqueryTables = globalState.getSchema().getRandomTableNonEmptyTables(); + PostgresSubquery subquery = createSubquery(globalState, String.format("sub%d", i), subqueryTables); + PostgresExpression joinClause = generateExpression(PostgresDataType.BOOLEAN); + PostgresJoinType options = PostgresJoinType.getRandom(); + PostgresJoin j = new PostgresJoin(subquery, joinClause, options); + joinStatements.add(j); + } + return joinStatements; + } + + @Override + public List getTableRefs() { + return targetTables.stream().map(t -> new PostgresFromTable(t, Randomly.getBoolean())) + .collect(Collectors.toList()); + } + + @Override + public List generateFetchColumns(boolean shouldCreateDummy) { + if (shouldCreateDummy && Randomly.getBooleanWithRatherLowProbability()) { + return Arrays.asList(new PostgresColumnValue(PostgresColumn.createDummy("*"), null)); + } + allowAggregateFunctions = true; + List fetchColumns = new ArrayList<>(); + List targetColumns = Randomly.nonEmptySubset(columns); + for (PostgresColumn c : targetColumns) { + fetchColumns.add(new PostgresColumnValue(c, null)); + } + allowAggregateFunctions = false; + return fetchColumns; + } + + @Override + public String generateOptimizedQueryString(PostgresSelect select, PostgresExpression whereCondition, + boolean shouldUseAggregate) { + PostgresColumnValue allColumns = new PostgresColumnValue(PostgresColumn.createDummy("*"), null); + if (shouldUseAggregate) { + select.setFetchColumns( + Arrays.asList(new PostgresAggregate(List.of(allColumns), PostgresAggregateFunction.COUNT))); + } else { + select.setFetchColumns(Arrays.asList(allColumns)); + } + select.setWhereClause(whereCondition); + if (Randomly.getBooleanWithRatherLowProbability()) { + select.setOrderByClauses(generateOrderBys()); + } + select.setSelectType(SelectType.ALL); + return select.asString(); + } + + @Override + public String generateUnoptimizedQueryString(PostgresSelect select, PostgresExpression whereCondition) { + PostgresCastOperation isTrue = new PostgresCastOperation(whereCondition, + PostgresCompoundDataType.create(PostgresDataType.INT)); + PostgresPostfixText asText = new PostgresPostfixText(isTrue, " as count", null, PostgresDataType.INT); + select.setFetchColumns(Arrays.asList(asText)); + select.setWhereClause(null); + select.setOrderByClauses(List.of()); + select.setSelectType(SelectType.ALL); + + return "SELECT SUM(count) FROM (" + select.asString() + ") as res"; + } + + @Override + public String generateExplainQuery(PostgresSelect select) { + return "EXPLAIN " + select.asString(); + } + + @Override + public boolean mutate(PostgresSelect select) { + List> mutators = new ArrayList<>(); + + mutators.add(this::mutateJoin); + mutators.add(this::mutateWhere); + mutators.add(this::mutateGroupBy); + mutators.add(this::mutateHaving); + mutators.add(this::mutateWindowFunction); + if (!PostgresBugs.bug18643) { + mutators.add(this::mutateAnd); + mutators.add(this::mutateOr); + } + mutators.add(this::mutateDistinct); + + return Randomly.fromList(mutators).apply(select); + } + + private boolean mutateWindowFunction(PostgresSelect select) { + List windowFunctions = select.getWindowFunctions(); + if (windowFunctions == null || windowFunctions.isEmpty()) { + windowFunctions = new ArrayList<>(); + windowFunctions.add(generateWindowFunction(0, PostgresDataType.INT)); + select.setWindowFunctions(windowFunctions); + return false; + } else { + windowFunctions.remove(Randomly.fromList(windowFunctions)); + if (windowFunctions.isEmpty()) { + select.setWindowFunctions(null); + } + return true; + } + } + + boolean mutateJoin(PostgresSelect select) { + if (select.getJoinList().isEmpty()) { + return false; + } + PostgresJoin join = (PostgresJoin) Randomly.fromList(select.getJoinList()); + + // Exclude CROSS for on condition + if (join.getType() == PostgresJoinType.CROSS) { + List columns = new ArrayList<>(); + columns.addAll(((PostgresTableReference) join.getLeftTable()).getTable().getColumns()); + columns.addAll(((PostgresTableReference) join.getRightTable()).getTable().getColumns()); + PostgresExpressionGenerator joinGen2 = new PostgresExpressionGenerator(globalState).setColumns(columns); + join.setOnClause(joinGen2.generateExpression(0, PostgresDataType.BOOLEAN)); + } + + PostgresJoinType newJoinType = PostgresJoinType.INNER; + if (join.getType() == PostgresJoinType.LEFT || join.getType() == PostgresJoinType.RIGHT) { + newJoinType = PostgresJoinType.getRandomExcept(PostgresJoinType.LEFT, PostgresJoinType.RIGHT); + } else { + newJoinType = PostgresJoinType.getRandomExcept(join.getType()); + } + boolean increase = join.getType().ordinal() < newJoinType.ordinal(); + join.setType(newJoinType); + if (newJoinType == PostgresJoinType.CROSS) { + join.setOnClause(null); + } + return increase; + } + + boolean mutateDistinct(PostgresSelect select) { + PostgresSelect.SelectType selectType = select.getSelectOption(); + if (selectType != PostgresSelect.SelectType.ALL) { + select.setSelectType(PostgresSelect.SelectType.ALL); + return true; + } else { + select.setSelectType(PostgresSelect.SelectType.DISTINCT); + return false; + } + } + + boolean mutateWhere(PostgresSelect select) { + boolean increase = select.getWhereClause() != null; + if (increase) { + select.setWhereClause(null); + } else { + select.setWhereClause(generateExpression(0, PostgresDataType.BOOLEAN)); + } + return increase; + } + + boolean mutateGroupBy(PostgresSelect select) { + boolean increase = !select.getGroupByExpressions().isEmpty(); + if (increase) { + select.clearGroupByExpressions(); + } else { + select.setGroupByExpressions(select.getFetchColumns()); + } + return increase; + } + + boolean mutateHaving(PostgresSelect select) { + if (select.getGroupByExpressions().isEmpty()) { + select.setGroupByExpressions(select.getFetchColumns()); + select.setHavingClause(generateExpression(0, PostgresDataType.BOOLEAN)); + return false; + } else { + if (select.getHavingClause() == null) { + select.setHavingClause(generateExpression(0, PostgresDataType.BOOLEAN)); + return false; + } else { + select.setHavingClause(null); + return true; + } + } + } + + boolean mutateAnd(PostgresSelect select) { + if (select.getWhereClause() == null) { + select.setWhereClause(generateExpression(0, PostgresDataType.BOOLEAN)); + } else { + PostgresExpression newWhere = new PostgresBinaryLogicalOperation(select.getWhereClause(), + generateExpression(0, PostgresDataType.BOOLEAN), BinaryLogicalOperator.AND); + select.setWhereClause(newWhere); + } + return false; + } + + boolean mutateOr(PostgresSelect select) { + if (select.getWhereClause() == null) { + select.setWhereClause(generateExpression(0, PostgresDataType.BOOLEAN)); + return false; + } else { + PostgresExpression newWhere = new PostgresBinaryLogicalOperation(select.getWhereClause(), + generateExpression(0, PostgresDataType.BOOLEAN), BinaryLogicalOperator.OR); + select.setWhereClause(newWhere); + return true; + } + } + + boolean mutateLimit(PostgresSelect select) { + boolean increase = select.getLimitClause() != null; + if (increase) { + select.setLimitClause(null); + } else { + Randomly r = new Randomly(); + select.setLimitClause(PostgresConstant.createIntConstant((int) Math.abs(r.getInteger()))); + } + return increase; + } } diff --git a/src/sqlancer/postgres/gen/PostgresInsertGenerator.java b/src/sqlancer/postgres/gen/PostgresInsertGenerator.java index 41c017b2e..49b94b184 100644 --- a/src/sqlancer/postgres/gen/PostgresInsertGenerator.java +++ b/src/sqlancer/postgres/gen/PostgresInsertGenerator.java @@ -7,6 +7,7 @@ import sqlancer.common.query.ExpectedErrors; import sqlancer.common.query.SQLQueryAdapter; import sqlancer.postgres.PostgresGlobalState; +import sqlancer.postgres.PostgresSchema; import sqlancer.postgres.PostgresSchema.PostgresColumn; import sqlancer.postgres.PostgresSchema.PostgresTable; import sqlancer.postgres.PostgresVisitor; @@ -19,6 +20,10 @@ private PostgresInsertGenerator() { public static SQLQueryAdapter insert(PostgresGlobalState globalState) { PostgresTable table = globalState.getSchema().getRandomTable(t -> t.isInsertable()); + return insertRows(globalState, table); + } + + public static SQLQueryAdapter insertRows(PostgresGlobalState globalState, PostgresSchema.PostgresTable table) { ExpectedErrors errors = new ExpectedErrors(); errors.add("cannot insert into column"); PostgresCommon.addCommonExpressionErrors(errors); diff --git a/src/sqlancer/postgres/gen/PostgresRandomQueryGenerator.java b/src/sqlancer/postgres/gen/PostgresRandomQueryGenerator.java index d1527ca8b..64d6278de 100644 --- a/src/sqlancer/postgres/gen/PostgresRandomQueryGenerator.java +++ b/src/sqlancer/postgres/gen/PostgresRandomQueryGenerator.java @@ -45,7 +45,7 @@ public static PostgresSelect createRandomQuery(int nrColumns, PostgresGlobalStat } } if (Randomly.getBooleanWithRatherLowProbability()) { - select.setOrderByExpressions(gen.generateOrderBy()); + select.setOrderByClauses(gen.generateOrderBys()); } if (Randomly.getBoolean()) { select.setLimitClause(PostgresConstant.createIntConstant(Randomly.getPositiveOrZeroNonCachedInteger())); diff --git a/src/sqlancer/postgres/gen/PostgresReindexGenerator.java b/src/sqlancer/postgres/gen/PostgresReindexGenerator.java index 9bb5ec5cd..d22ffe53e 100644 --- a/src/sqlancer/postgres/gen/PostgresReindexGenerator.java +++ b/src/sqlancer/postgres/gen/PostgresReindexGenerator.java @@ -1,7 +1,6 @@ package sqlancer.postgres.gen; import java.util.List; -import java.util.stream.Collectors; import sqlancer.IgnoreMeException; import sqlancer.Randomly; @@ -39,7 +38,7 @@ public static SQLQueryAdapter create(PostgresGlobalState globalState) { if (indexes.isEmpty()) { throw new IgnoreMeException(); } - sb.append(indexes.stream().map(i -> i.getIndexName()).collect(Collectors.joining())); + sb.append(Randomly.fromList(indexes).getIndexName()); break; case TABLE: sb.append("TABLE "); @@ -59,7 +58,6 @@ public static SQLQueryAdapter create(PostgresGlobalState globalState) { throw new AssertionError(scope); } errors.add("already contains data"); // FIXME bug report - errors.add("does not exist"); // internal index errors.add("REINDEX is not yet implemented for partitioned indexes"); return new SQLQueryAdapter(sb.toString(), errors); } diff --git a/src/sqlancer/postgres/gen/PostgresSetGenerator.java b/src/sqlancer/postgres/gen/PostgresSetGenerator.java index f600f1f04..38318b26b 100644 --- a/src/sqlancer/postgres/gen/PostgresSetGenerator.java +++ b/src/sqlancer/postgres/gen/PostgresSetGenerator.java @@ -14,7 +14,7 @@ private PostgresSetGenerator() { } private enum ConfigurationOption { - // https://www.postgresql.org/docs/11/runtime-config-wal.html + // https://www.postgresql.org/docs/13/runtime-config-wal.html // This parameter can only be set at server start. // WAL_LEVEL("wal_level", (r) -> Randomly.fromOptions("replica", "minimal", "logical")), // FSYNC("fsync", (r) -> Randomly.fromOptions(1, 0)), @@ -37,7 +37,7 @@ private enum ConfigurationOption { // archive_mode // archive_command // archive_timeout - // https://www.postgresql.org/docs/11/runtime-config-statistics.html + // https://www.postgresql.org/docs/13/runtime-config-statistics.html // 19.9.1. Query and Index Statistics Collector TRACK_ACTIVITIES("track_activities", (r) -> Randomly.fromOptions(1, 0)), // track_activity_query_size @@ -46,7 +46,7 @@ private enum ConfigurationOption { TRACK_FUNCTIONS("track_functions", (r) -> Randomly.fromOptions("'none'", "'pl'", "'all'")), // stats_temp_directory // TODO 19.9.2. Statistics Monitoring - // https://www.postgresql.org/docs/11/runtime-config-autovacuum.html + // https://www.postgresql.org/docs/13/runtime-config-autovacuum.html // all can only be set at server-conf time // 19.11. Client Connection Defaults VACUUM_FREEZE_TABLE_AGE("vacuum_freeze_table_age", (r) -> Randomly.fromOptions(0, 5, 10, 100, 500, 2000000000)), @@ -60,7 +60,7 @@ private enum ConfigurationOption { // 19.13. Version and Platform Compatibility DEFAULT_WITH_OIDS("default_with_oids", (r) -> Randomly.fromOptions(0, 1)), SYNCHRONIZED_SEQSCANS("synchronize_seqscans", (r) -> Randomly.fromOptions(0, 1)), - // https://www.postgresql.org/docs/devel/runtime-config-query.html + // https://www.postgresql.org/docs/13/runtime-config-query.html ENABLE_BITMAPSCAN("enable_bitmapscan", (r) -> Randomly.fromOptions(1, 0)), ENABLE_GATHERMERGE("enable_gathermerge", (r) -> Randomly.fromOptions(1, 0)), ENABLE_HASHJOIN("enable_hashjoin", (r) -> Randomly.fromOptions(1, 0)), @@ -78,7 +78,7 @@ private enum ConfigurationOption { ENABLE_SORT("enable_sort", (r) -> Randomly.fromOptions(1, 0)), ENABLE_TIDSCAN("enable_tidscan", (r) -> Randomly.fromOptions(1, 0)), // 19.7.2. Planner Cost Constants (complete as of March 2020) - // https://www.postgresql.org/docs/current/runtime-config-query.html#RUNTIME-CONFIG-QUERY-CONSTANTS + // https://www.postgresql.org/docs/13/runtime-config-query.html#RUNTIME-CONFIG-QUERY-CONSTANTS SEQ_PAGE_COST("seq_page_cost", (r) -> Randomly.fromOptions(0d, 0.00001, 0.05, 0.1, 1, 10, 10000)), RANDOM_PAGE_COST("random_page_cost", (r) -> Randomly.fromOptions(0d, 0.00001, 0.05, 0.1, 1, 10, 10000)), CPU_TUPLE_COST("cpu_tuple_cost", (r) -> Randomly.fromOptions(0d, 0.00001, 0.05, 0.1, 1, 10, 10000)), @@ -94,7 +94,7 @@ private enum ConfigurationOption { JIT_OPTIMIZE_ABOVE_COST("jit_optimize_above_cost", (r) -> Randomly.fromOptions(0, r.getLong(-1, Long.MAX_VALUE))), // 19.7.3. Genetic Query Optimizer (complete as of March 2020) - // https://www.postgresql.org/docs/current/runtime-config-query.html#RUNTIME-CONFIG-QUERY-GEQO + // https://www.postgresql.org/docs/13/runtime-config-query.html#RUNTIME-CONFIG-QUERY-GEQO GEQO("geqo", (r) -> Randomly.fromOptions(1, 0)), GEQO_THRESHOLD("geqo_threshold", (r) -> r.getInteger(2, 2147483647)), GEQO_EFFORT("geqo_effort", (r) -> r.getInteger(1, 10)), @@ -103,7 +103,7 @@ private enum ConfigurationOption { GEQO_SELECTION_BIAS("geqo_selection_bias", (r) -> Randomly.fromOptions(1.5, 1.8, 2.0)), GEQO_SEED("geqo_seed", (r) -> Randomly.fromOptions(0, 0.5, 1)), // 19.7.4. Other Planner Options (complete as of March 2020) - // https://www.postgresql.org/docs/current/runtime-config-query.html#RUNTIME-CONFIG-QUERY-OTHER + // https://www.postgresql.org/docs/13/runtime-config-query.html#RUNTIME-CONFIG-QUERY-OTHER DEFAULT_STATISTICS_TARGET("default_statistics_target", (r) -> r.getInteger(1, 10000)), CONSTRAINT_EXCLUSION("constraint_exclusion", (r) -> Randomly.fromOptions("on", "off", "partition")), CURSOR_TUPLE_FRACTION("cursor_tuple_fraction", @@ -112,7 +112,7 @@ private enum ConfigurationOption { JIT("jit", (r) -> Randomly.fromOptions(1, 0)), JOIN_COLLAPSE_LIMIT("join_collapse_limit", (r) -> r.getInteger(1, Integer.MAX_VALUE)), PARALLEL_LEADER_PARTICIPATION("parallel_leader_participation", (r) -> Randomly.fromOptions(1, 0)), - FORCE_PARALLEL_MODE("force_parallel_mode", (r) -> Randomly.fromOptions("off", "on", "regress")), + // FORCE_PARALLEL_MODE("force_parallel_mode", (r) -> Randomly.fromOptions("off", "on", "regress")), PLAN_CACHE_MODE("plan_cache_mode", (r) -> Randomly.fromOptions("auto", "force_generic_plan", "force_custom_plan")); diff --git a/src/sqlancer/postgres/gen/PostgresStatisticsGenerator.java b/src/sqlancer/postgres/gen/PostgresStatisticsGenerator.java index e27f82800..9a82e29af 100644 --- a/src/sqlancer/postgres/gen/PostgresStatisticsGenerator.java +++ b/src/sqlancer/postgres/gen/PostgresStatisticsGenerator.java @@ -58,6 +58,20 @@ public static SQLQueryAdapter remove(PostgresGlobalState globalState) { return new SQLQueryAdapter(sb.toString(), true); } + public static SQLQueryAdapter alter(PostgresGlobalState globalState) { + StringBuilder sb = new StringBuilder("ALTER STATISTICS "); + PostgresTable randomTable = globalState.getSchema().getRandomTable(); + List statistics = randomTable.getStatistics(); + if (statistics.isEmpty()) { + throw new IgnoreMeException(); + } + PostgresStatisticsObject randomStatistic = Randomly.fromList(statistics); + sb.append(randomStatistic.getName()); + sb.append(" SET STATISTICS "); + sb.append(Randomly.getNotCachedInteger(-1, 10000)); // -1 means default + return new SQLQueryAdapter(sb.toString(), true); + } + private static String getNewStatisticsName(PostgresTable randomTable) { List statistics = randomTable.getStatistics(); int i = 0; diff --git a/src/sqlancer/postgres/gen/PostgresTableGenerator.java b/src/sqlancer/postgres/gen/PostgresTableGenerator.java index 7d65b8ebd..29ccfcf2c 100644 --- a/src/sqlancer/postgres/gen/PostgresTableGenerator.java +++ b/src/sqlancer/postgres/gen/PostgresTableGenerator.java @@ -113,7 +113,9 @@ private void createStandard() throws AssertionError { generateInherits(); generatePartitionBy(); generateUsing(); - PostgresCommon.generateWith(sb, globalState, errors); + if (!isPartitionedTable) { + PostgresCommon.generateWith(sb, globalState, errors); + } if (Randomly.getBoolean() && isTemporaryTable) { sb.append(" ON COMMIT "); sb.append(Randomly.fromOptions("PRESERVE ROWS", "DELETE ROWS", "DROP")); @@ -205,7 +207,7 @@ private void generateUsing() { } private void generateInherits() { - if (Randomly.getBoolean() && !newSchema.getDatabaseTables().isEmpty()) { + if (Randomly.getBoolean() && !newSchema.getDatabaseTablesWithoutViews().isEmpty()) { sb.append(" INHERITS("); sb.append(newSchema.getDatabaseTablesRandomSubsetNotEmpty().stream().map(t -> t.getName()) .collect(Collectors.joining(", "))); diff --git a/src/sqlancer/postgres/gen/PostgresTableSpaceGenerator.java b/src/sqlancer/postgres/gen/PostgresTableSpaceGenerator.java new file mode 100644 index 000000000..3890d5160 --- /dev/null +++ b/src/sqlancer/postgres/gen/PostgresTableSpaceGenerator.java @@ -0,0 +1,55 @@ +package sqlancer.postgres.gen; + +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.postgres.PostgresGlobalState; +import sqlancer.postgres.PostgresOptions; + +public class PostgresTableSpaceGenerator { + + private final ExpectedErrors errors = new ExpectedErrors(); + private final PostgresGlobalState globalState; + + public PostgresTableSpaceGenerator(PostgresGlobalState globalState) { + this.globalState = globalState; + errors.addRegexString("ERROR: (?:tablespace )?directory \".*[\\\\/]tablespace\\d+\" does not exist"); + errors.add("ERROR: already exists"); + errors.add("ERROR: is not empty"); + errors.add("ERROR: cannot be created because system does not support tablespaces"); + } + + public static SQLQueryAdapter generate(PostgresGlobalState globalState) { + // Skip tablespace generation if the option is disabled + PostgresOptions options = globalState.getDbmsSpecificOptions(); + if (!options.isTestTablespaces()) { + return null; + } + return new PostgresTableSpaceGenerator(globalState).generateTableSpace(); + } + + private SQLQueryAdapter generateTableSpace() { + StringBuilder sb = new StringBuilder(); + int tableSpaceNum = globalState.getRandomly().getInteger(1, Integer.MAX_VALUE); + + // CREATE TABLESPACE syntax + sb.append("CREATE TABLESPACE "); + sb.append("tablespace"); + sb.append(tableSpaceNum); + sb.append(" LOCATION '"); + + // Get the validated base path from options and append the tablespace number + PostgresOptions options = globalState.getDbmsSpecificOptions(); + String path = options.getTablespacePath() + tableSpaceNum; + + // Convert backslashes to forward slashes for PostgreSQL + path = path.replace('\\', '/'); + + // Escape single quotes in the path + path = path.replace("'", "''"); + + sb.append(path); + sb.append("'"); + + return new SQLQueryAdapter(sb.toString(), errors); + } +} diff --git a/src/sqlancer/postgres/gen/PostgresTruncateGenerator.java b/src/sqlancer/postgres/gen/PostgresTruncateGenerator.java index 197668ac5..0745d1cce 100644 --- a/src/sqlancer/postgres/gen/PostgresTruncateGenerator.java +++ b/src/sqlancer/postgres/gen/PostgresTruncateGenerator.java @@ -33,8 +33,9 @@ public static SQLQueryAdapter create(PostgresGlobalState globalState) { sb.append(" "); sb.append(Randomly.fromOptions("CASCADE", "RESTRICT")); } - return new SQLQueryAdapter(sb.toString(), ExpectedErrors - .from("cannot truncate a table referenced in a foreign key constraint", "is not a table")); + return new SQLQueryAdapter(sb.toString(), + ExpectedErrors.from("cannot truncate a table referenced in a foreign key constraint", "is not a table", + "is not distributed")); } } diff --git a/src/sqlancer/postgres/gen/PostgresUpdateGenerator.java b/src/sqlancer/postgres/gen/PostgresUpdateGenerator.java index 1122b2b5a..7ce7fe882 100644 --- a/src/sqlancer/postgres/gen/PostgresUpdateGenerator.java +++ b/src/sqlancer/postgres/gen/PostgresUpdateGenerator.java @@ -1,9 +1,10 @@ package sqlancer.postgres.gen; +import java.util.Arrays; import java.util.List; import sqlancer.Randomly; -import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.gen.AbstractUpdateGenerator; import sqlancer.common.query.SQLQueryAdapter; import sqlancer.postgres.PostgresGlobalState; import sqlancer.postgres.PostgresSchema.PostgresColumn; @@ -12,51 +13,36 @@ import sqlancer.postgres.PostgresVisitor; import sqlancer.postgres.ast.PostgresExpression; -public final class PostgresUpdateGenerator { +public final class PostgresUpdateGenerator extends AbstractUpdateGenerator { - private PostgresUpdateGenerator() { + private final PostgresGlobalState globalState; + private PostgresTable randomTable; + + private PostgresUpdateGenerator(PostgresGlobalState globalState) { + this.globalState = globalState; + errors.addAll(Arrays.asList("conflicting key value violates exclusion constraint", + "reached maximum value of sequence", "violates foreign key constraint", "violates not-null constraint", + "violates unique constraint", "out of range", "cannot cast", "must be type boolean", "is not unique", + " bit string too long", "can only be updated to DEFAULT", "division by zero", + "You might need to add explicit type casts.", "invalid regular expression", + "View columns that are not columns of their base relation are not updatable")); } public static SQLQueryAdapter create(PostgresGlobalState globalState) { - PostgresTable randomTable = globalState.getSchema().getRandomTable(t -> t.isInsertable()); - StringBuilder sb = new StringBuilder(); + return new PostgresUpdateGenerator(globalState).generate(); + } + + private SQLQueryAdapter generate() { + randomTable = globalState.getSchema().getRandomTable(t -> t.isInsertable()); + List columns = randomTable.getRandomNonEmptyColumnSubset(); sb.append("UPDATE "); sb.append(randomTable.getName()); sb.append(" SET "); - ExpectedErrors errors = ExpectedErrors.from("conflicting key value violates exclusion constraint", - "reached maximum value of sequence", "violates foreign key constraint", "violates not-null constraint", - "violates unique constraint", "out of range", "cannot cast", "must be type boolean", "is not unique", - " bit string too long", "can only be updated to DEFAULT", "division by zero", - "You might need to add explicit type casts.", "invalid regular expression", - "View columns that are not columns of their base relation are not updatable"); errors.add("multiple assignments to same column"); // view whose columns refer to a column in the referenced // table multiple times errors.add("new row violates check option for view"); - List columns = randomTable.getRandomNonEmptyColumnSubset(); PostgresCommon.addCommonInsertUpdateErrors(errors); - - for (int i = 0; i < columns.size(); i++) { - if (i != 0) { - sb.append(", "); - } - PostgresColumn column = columns.get(i); - sb.append(column.getName()); - sb.append(" = "); - if (!Randomly.getBoolean()) { - PostgresExpression constant = PostgresExpressionGenerator.generateConstant(globalState.getRandomly(), - column.getType()); - sb.append(PostgresVisitor.asString(constant)); - } else if (Randomly.getBoolean()) { - sb.append("DEFAULT"); - } else { - sb.append("("); - PostgresExpression expr = PostgresExpressionGenerator.generateExpression(globalState, - randomTable.getColumns(), column.getType()); - // caused by casts - sb.append(PostgresVisitor.asString(expr)); - sb.append(")"); - } - } + updateColumns(columns); errors.add("invalid input syntax for "); errors.add("operator does not exist: text = boolean"); errors.add("violates check constraint"); @@ -73,4 +59,22 @@ public static SQLQueryAdapter create(PostgresGlobalState globalState) { return new SQLQueryAdapter(sb.toString(), errors, true); } + @Override + protected void updateValue(PostgresColumn column) { + if (!Randomly.getBoolean()) { + PostgresExpression constant = PostgresExpressionGenerator.generateConstant(globalState.getRandomly(), + column.getType()); + sb.append(PostgresVisitor.asString(constant)); + } else if (Randomly.getBoolean()) { + sb.append("DEFAULT"); + } else { + sb.append("("); + PostgresExpression expr = PostgresExpressionGenerator.generateExpression(globalState, + randomTable.getColumns(), column.getType()); + // caused by casts + sb.append(PostgresVisitor.asString(expr)); + sb.append(")"); + } + } + } diff --git a/src/sqlancer/postgres/gen/PostgresViewGenerator.java b/src/sqlancer/postgres/gen/PostgresViewGenerator.java index b0a2a8b9d..10992ece6 100644 --- a/src/sqlancer/postgres/gen/PostgresViewGenerator.java +++ b/src/sqlancer/postgres/gen/PostgresViewGenerator.java @@ -35,19 +35,11 @@ public static SQLQueryAdapter create(PostgresGlobalState globalState) { materialized = false; } sb.append(" VIEW "); - int i = 0; - String[] name = new String[1]; - while (true) { - name[0] = "v" + i++; - if (globalState.getSchema().getDatabaseTables().stream() - .noneMatch(tab -> tab.getName().contentEquals(name[0]))) { - break; - } - } - sb.append(name[0]); + String name = globalState.getSchema().getFreeViewName(); + sb.append(name); sb.append("("); int nrColumns = Randomly.smallNumber() + 1; - for (i = 0; i < nrColumns; i++) { + for (int i = 0; i < nrColumns; i++) { if (i != 0) { sb.append(", "); } diff --git a/src/sqlancer/postgres/gen/PostgresWindowFunctionGenerator.java b/src/sqlancer/postgres/gen/PostgresWindowFunctionGenerator.java new file mode 100644 index 000000000..6daf66c3c --- /dev/null +++ b/src/sqlancer/postgres/gen/PostgresWindowFunctionGenerator.java @@ -0,0 +1,131 @@ +package sqlancer.postgres.gen; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.postgres.PostgresGlobalState; +import sqlancer.postgres.PostgresSchema.PostgresDataType; +import sqlancer.postgres.ast.PostgresConstant; +import sqlancer.postgres.ast.PostgresExpression; +import sqlancer.postgres.ast.PostgresOrderByTerm; +import sqlancer.postgres.ast.PostgresWindowFunction; +import sqlancer.postgres.ast.PostgresWindowFunction.WindowFrame; +import sqlancer.postgres.ast.PostgresWindowFunction.WindowSpecification; + +public final class PostgresWindowFunctionGenerator { + + private static final List WINDOW_FUNCTIONS = Arrays.asList("row_number", "rank", "dense_rank", + "percent_rank", "cume_dist", "ntile", "lag", "lead", "first_value", "last_value", "nth_value"); + + private PostgresWindowFunctionGenerator() { + throw new AssertionError("Utility class should not be instantiated"); + } + + public static PostgresWindowFunction generateWindowFunction(PostgresGlobalState globalState, + List availableExpr) { + + String functionName = selectRandomWindowFunction(); + List arguments = generateFunctionArguments(functionName, globalState, availableExpr); + WindowSpecification windowSpec = generateWindowSpecification(globalState, availableExpr); + PostgresDataType returnType = determineReturnType(functionName); + + return new PostgresWindowFunction(functionName, arguments, windowSpec, returnType); + } + + private static String selectRandomWindowFunction() { + return Randomly.fromList(WINDOW_FUNCTIONS); + } + + private static List generateFunctionArguments(String functionName, + PostgresGlobalState globalState, List availableExpr) { + List arguments = new ArrayList<>(); + + switch (functionName) { + case "ntile": + arguments + .add(PostgresExpressionGenerator.generateConstant(globalState.getRandomly(), PostgresDataType.INT)); + break; + case "lag": + case "lead": + case "nth_value": + arguments.add(Randomly.fromList(availableExpr)); + if (Randomly.getBoolean()) { + arguments.add( + PostgresExpressionGenerator.generateConstant(globalState.getRandomly(), PostgresDataType.INT)); + } + break; + case "first_value": + case "last_value": + arguments.add(Randomly.fromList(availableExpr)); + break; + default: + // No arguments needed for other window functions + break; + } + + return arguments; + } + + private static WindowSpecification generateWindowSpecification(PostgresGlobalState globalState, + List availableExpr) { + List partitionBy = generatePartitionByClause(availableExpr); + PostgresExpressionGenerator exprGen = new PostgresExpressionGenerator(globalState); + List orderBys = exprGen.generateOrderBys(); + List orderByTerms = new ArrayList<>(); + for (PostgresExpression expr : orderBys) { + orderByTerms.add(new PostgresOrderByTerm(expr, Randomly.getBoolean())); + } + + WindowFrame frame = generateWindowFrame(globalState); + return new WindowSpecification(partitionBy, orderByTerms, frame); + } + + private static List generatePartitionByClause(List availableExpr) { + List partitionBy = new ArrayList<>(); + if (Randomly.getBooleanWithRatherLowProbability()) { + int count = Randomly.smallNumber(); + for (int i = 0; i < count; i++) { + partitionBy.add(Randomly.fromList(availableExpr)); + } + } + return partitionBy; + } + + private static WindowFrame generateWindowFrame(PostgresGlobalState globalState) { + if (Randomly.getBooleanWithRatherLowProbability()) { + WindowFrame.FrameType frameType = Randomly.fromOptions(WindowFrame.FrameType.values()); + PostgresExpression startExpr = generateFrameBound(globalState); + PostgresExpression endExpr = generateFrameBound(globalState); + return new WindowFrame(frameType, startExpr, endExpr); + } + return null; + } + + private static PostgresExpression generateFrameBound(PostgresGlobalState globalState) { + if (Randomly.getBooleanWithRatherLowProbability()) { + return generateCurrentRowBound(); + } else { + return generateOffsetBound(globalState); + } + } + + private static PostgresExpression generateCurrentRowBound() { + return PostgresConstant.createIntConstant(0); + } + + private static PostgresExpression generateOffsetBound(PostgresGlobalState globalState) { + return PostgresConstant.createIntConstant(globalState.getRandomly().getInteger()); + } + + private static PostgresDataType determineReturnType(String functionName) { + switch (functionName) { + case "percent_rank": + case "cume_dist": + return PostgresDataType.FLOAT; + default: + return PostgresDataType.INT; + } + } +} diff --git a/src/sqlancer/postgres/oracle/PostgresFuzzer.java b/src/sqlancer/postgres/oracle/PostgresFuzzer.java new file mode 100644 index 000000000..a0b46bd17 --- /dev/null +++ b/src/sqlancer/postgres/oracle/PostgresFuzzer.java @@ -0,0 +1,30 @@ +package sqlancer.postgres.oracle; + +import sqlancer.Randomly; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.postgres.PostgresGlobalState; +import sqlancer.postgres.PostgresVisitor; +import sqlancer.postgres.gen.PostgresRandomQueryGenerator; + +public class PostgresFuzzer implements TestOracle { + + private final PostgresGlobalState globalState; + + public PostgresFuzzer(PostgresGlobalState globalState) { + this.globalState = globalState; + } + + @Override + public void check() throws Exception { + String s = PostgresVisitor.asString( + PostgresRandomQueryGenerator.createRandomQuery(Randomly.smallNumber() + 1, globalState)) + ';'; + try { + globalState.executeStatement(new SQLQueryAdapter(s)); + globalState.getManager().incrementSelectQueryCount(); + } catch (Error e) { + + } + } + +} diff --git a/src/sqlancer/postgres/oracle/PostgresNoRECOracle.java b/src/sqlancer/postgres/oracle/PostgresNoRECOracle.java deleted file mode 100644 index 995eb207f..000000000 --- a/src/sqlancer/postgres/oracle/PostgresNoRECOracle.java +++ /dev/null @@ -1,168 +0,0 @@ -package sqlancer.postgres.oracle; - -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Statement; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.stream.Collectors; - -import sqlancer.IgnoreMeException; -import sqlancer.Randomly; -import sqlancer.common.oracle.NoRECBase; -import sqlancer.common.oracle.TestOracle; -import sqlancer.common.query.SQLQueryAdapter; -import sqlancer.common.query.SQLancerResultSet; -import sqlancer.postgres.PostgresCompoundDataType; -import sqlancer.postgres.PostgresGlobalState; -import sqlancer.postgres.PostgresSchema; -import sqlancer.postgres.PostgresSchema.PostgresColumn; -import sqlancer.postgres.PostgresSchema.PostgresDataType; -import sqlancer.postgres.PostgresSchema.PostgresTable; -import sqlancer.postgres.PostgresSchema.PostgresTables; -import sqlancer.postgres.PostgresVisitor; -import sqlancer.postgres.ast.PostgresCastOperation; -import sqlancer.postgres.ast.PostgresColumnValue; -import sqlancer.postgres.ast.PostgresExpression; -import sqlancer.postgres.ast.PostgresJoin; -import sqlancer.postgres.ast.PostgresJoin.PostgresJoinType; -import sqlancer.postgres.ast.PostgresPostfixText; -import sqlancer.postgres.ast.PostgresSelect; -import sqlancer.postgres.ast.PostgresSelect.PostgresFromTable; -import sqlancer.postgres.ast.PostgresSelect.PostgresSubquery; -import sqlancer.postgres.ast.PostgresSelect.SelectType; -import sqlancer.postgres.gen.PostgresCommon; -import sqlancer.postgres.gen.PostgresExpressionGenerator; -import sqlancer.postgres.oracle.tlp.PostgresTLPBase; - -public class PostgresNoRECOracle extends NoRECBase implements TestOracle { - - private final PostgresSchema s; - - public PostgresNoRECOracle(PostgresGlobalState globalState) { - super(globalState); - this.s = globalState.getSchema(); - PostgresCommon.addCommonExpressionErrors(errors); - PostgresCommon.addCommonFetchErrors(errors); - } - - @Override - public void check() throws SQLException { - PostgresTables randomTables = s.getRandomTableNonEmptyTables(); - List columns = randomTables.getColumns(); - PostgresExpression randomWhereCondition = getRandomWhereCondition(columns); - List tables = randomTables.getTables(); - - List joinStatements = getJoinStatements(state, columns, tables); - List fromTables = tables.stream().map(t -> new PostgresFromTable(t, Randomly.getBoolean())) - .collect(Collectors.toList()); - int secondCount = getUnoptimizedQueryCount(fromTables, randomWhereCondition, joinStatements); - int firstCount = getOptimizedQueryCount(fromTables, columns, randomWhereCondition, joinStatements); - if (firstCount == -1 || secondCount == -1) { - throw new IgnoreMeException(); - } - if (firstCount != secondCount) { - String queryFormatString = "-- %s;\n-- count: %d"; - String firstQueryStringWithCount = String.format(queryFormatString, optimizedQueryString, firstCount); - String secondQueryStringWithCount = String.format(queryFormatString, unoptimizedQueryString, secondCount); - state.getState().getLocalState() - .log(String.format("%s\n%s", firstQueryStringWithCount, secondQueryStringWithCount)); - String assertionMessage = String.format("the counts mismatch (%d and %d)!\n%s\n%s", firstCount, secondCount, - firstQueryStringWithCount, secondQueryStringWithCount); - throw new AssertionError(assertionMessage); - } - } - - public static List getJoinStatements(PostgresGlobalState globalState, List columns, - List tables) { - List joinStatements = new ArrayList<>(); - PostgresExpressionGenerator gen = new PostgresExpressionGenerator(globalState).setColumns(columns); - for (int i = 1; i < tables.size(); i++) { - PostgresExpression joinClause = gen.generateExpression(PostgresDataType.BOOLEAN); - PostgresTable table = Randomly.fromList(tables); - tables.remove(table); - PostgresJoinType options = PostgresJoinType.getRandom(); - PostgresJoin j = new PostgresJoin(new PostgresFromTable(table, Randomly.getBoolean()), joinClause, options); - joinStatements.add(j); - } - // JOIN subqueries - for (int i = 0; i < Randomly.smallNumber(); i++) { - PostgresTables subqueryTables = globalState.getSchema().getRandomTableNonEmptyTables(); - PostgresSubquery subquery = PostgresTLPBase.createSubquery(globalState, String.format("sub%d", i), - subqueryTables); - PostgresExpression joinClause = gen.generateExpression(PostgresDataType.BOOLEAN); - PostgresJoinType options = PostgresJoinType.getRandom(); - PostgresJoin j = new PostgresJoin(subquery, joinClause, options); - joinStatements.add(j); - } - return joinStatements; - } - - private PostgresExpression getRandomWhereCondition(List columns) { - return new PostgresExpressionGenerator(state).setColumns(columns).generateExpression(PostgresDataType.BOOLEAN); - } - - private int getUnoptimizedQueryCount(List fromTables, PostgresExpression randomWhereCondition, - List joinStatements) throws SQLException { - PostgresSelect select = new PostgresSelect(); - PostgresCastOperation isTrue = new PostgresCastOperation(randomWhereCondition, - PostgresCompoundDataType.create(PostgresDataType.INT)); - PostgresPostfixText asText = new PostgresPostfixText(isTrue, " as count", null, PostgresDataType.INT); - select.setFetchColumns(Arrays.asList(asText)); - select.setFromList(fromTables); - select.setSelectType(SelectType.ALL); - select.setJoinClauses(joinStatements); - int secondCount = 0; - unoptimizedQueryString = "SELECT SUM(count) FROM (" + PostgresVisitor.asString(select) + ") as res"; - if (options.logEachSelect()) { - logger.writeCurrent(unoptimizedQueryString); - } - errors.add("canceling statement due to statement timeout"); - SQLQueryAdapter q = new SQLQueryAdapter(unoptimizedQueryString, errors); - SQLancerResultSet rs; - try { - rs = q.executeAndGet(state); - } catch (Exception e) { - throw new AssertionError(unoptimizedQueryString, e); - } - if (rs == null) { - return -1; - } - if (rs.next()) { - secondCount += rs.getLong(1); - } - rs.close(); - return secondCount; - } - - private int getOptimizedQueryCount(List randomTables, List columns, - PostgresExpression randomWhereCondition, List joinStatements) throws SQLException { - PostgresSelect select = new PostgresSelect(); - PostgresColumnValue allColumns = new PostgresColumnValue(Randomly.fromList(columns), null); - select.setFetchColumns(Arrays.asList(allColumns)); - select.setFromList(randomTables); - select.setWhereClause(randomWhereCondition); - if (Randomly.getBooleanWithSmallProbability()) { - select.setOrderByExpressions(new PostgresExpressionGenerator(state).setColumns(columns).generateOrderBy()); - } - select.setSelectType(SelectType.ALL); - select.setJoinClauses(joinStatements); - int firstCount = 0; - try (Statement stat = con.createStatement()) { - optimizedQueryString = PostgresVisitor.asString(select); - if (options.logEachSelect()) { - logger.writeCurrent(optimizedQueryString); - } - try (ResultSet rs = stat.executeQuery(optimizedQueryString)) { - while (rs.next()) { - firstCount++; - } - } - } catch (SQLException e) { - throw new IgnoreMeException(); - } - return firstCount; - } - -} diff --git a/src/sqlancer/postgres/oracle/PostgresPivotedQuerySynthesisOracle.java b/src/sqlancer/postgres/oracle/PostgresPivotedQuerySynthesisOracle.java index 778a7ca90..9db5365a3 100644 --- a/src/sqlancer/postgres/oracle/PostgresPivotedQuerySynthesisOracle.java +++ b/src/sqlancer/postgres/oracle/PostgresPivotedQuerySynthesisOracle.java @@ -63,8 +63,8 @@ public SQLQueryAdapter getRectifiedQuery() throws SQLException { selectStatement.setOffsetClause(offsetClause); } List orderBy = new PostgresExpressionGenerator(globalState).setColumns(columns) - .generateOrderBy(); - selectStatement.setOrderByExpressions(orderBy); + .generateOrderBys(); + selectStatement.setOrderByClauses(orderBy); return new SQLQueryAdapter(PostgresVisitor.asString(selectStatement)); } diff --git a/src/sqlancer/postgres/oracle/tlp/PostgresTLPAggregateOracle.java b/src/sqlancer/postgres/oracle/tlp/PostgresTLPAggregateOracle.java index b9ea695e2..dba8c0a6d 100644 --- a/src/sqlancer/postgres/oracle/tlp/PostgresTLPAggregateOracle.java +++ b/src/sqlancer/postgres/oracle/tlp/PostgresTLPAggregateOracle.java @@ -29,7 +29,7 @@ import sqlancer.postgres.ast.PostgresSelect; import sqlancer.postgres.gen.PostgresCommon; -public class PostgresTLPAggregateOracle extends PostgresTLPBase implements TestOracle { +public class PostgresTLPAggregateOracle extends PostgresTLPBase implements TestOracle { private String firstResult; private String secondResult; @@ -61,7 +61,7 @@ protected void aggregateCheck() throws SQLException { } select.setFetchColumns(Arrays.asList(aggregate)); if (Randomly.getBooleanWithRatherLowProbability()) { - select.setOrderByExpressions(gen.generateOrderBy()); + select.setOrderByClauses(gen.generateOrderBys()); } originalQuery = PostgresVisitor.asString(select); firstResult = getAggregateResult(originalQuery); diff --git a/src/sqlancer/postgres/oracle/tlp/PostgresTLPBase.java b/src/sqlancer/postgres/oracle/tlp/PostgresTLPBase.java index 34f5e602a..55b5d2dcf 100644 --- a/src/sqlancer/postgres/oracle/tlp/PostgresTLPBase.java +++ b/src/sqlancer/postgres/oracle/tlp/PostgresTLPBase.java @@ -20,16 +20,16 @@ import sqlancer.postgres.ast.PostgresConstant; import sqlancer.postgres.ast.PostgresExpression; import sqlancer.postgres.ast.PostgresJoin; +import sqlancer.postgres.ast.PostgresJoin.PostgresJoinType; import sqlancer.postgres.ast.PostgresSelect; import sqlancer.postgres.ast.PostgresSelect.ForClause; import sqlancer.postgres.ast.PostgresSelect.PostgresFromTable; import sqlancer.postgres.ast.PostgresSelect.PostgresSubquery; import sqlancer.postgres.gen.PostgresCommon; import sqlancer.postgres.gen.PostgresExpressionGenerator; -import sqlancer.postgres.oracle.PostgresNoRECOracle; public class PostgresTLPBase extends TernaryLogicPartitioningOracleBase - implements TestOracle { + implements TestOracle { protected PostgresSchema s; protected PostgresTables targetTables; @@ -53,8 +53,27 @@ public void check() throws SQLException { protected List getJoinStatements(PostgresGlobalState globalState, List columns, List tables) { - return PostgresNoRECOracle.getJoinStatements(state, columns, tables); - // TODO joins + List joinStatements = new ArrayList<>(); + PostgresExpressionGenerator gen = new PostgresExpressionGenerator(globalState).setColumns(columns); + for (int i = 1; i < tables.size(); i++) { + PostgresExpression joinClause = gen.generateExpression(PostgresDataType.BOOLEAN); + PostgresTable table = Randomly.fromList(tables); + tables.remove(table); + PostgresJoinType options = PostgresJoinType.getRandom(); + PostgresJoin j = new PostgresJoin(new PostgresFromTable(table, Randomly.getBoolean()), joinClause, options); + joinStatements.add(j); + } + // JOIN subqueries + for (int i = 0; i < Randomly.smallNumber(); i++) { + PostgresTables subqueryTables = globalState.getSchema().getRandomTableNonEmptyTables(); + PostgresSubquery subquery = PostgresTLPBase.createSubquery(globalState, String.format("sub%d", i), + subqueryTables); + PostgresExpression joinClause = gen.generateExpression(PostgresDataType.BOOLEAN); + PostgresJoinType options = PostgresJoinType.getRandom(); + PostgresJoin j = new PostgresJoin(subquery, joinClause, options); + joinStatements.add(j); + } + return joinStatements; } protected void generateSelectBase(List tables, List joins) { @@ -103,7 +122,7 @@ public static PostgresSubquery createSubquery(PostgresGlobalState globalState, S select.setWhereClause(gen.generateExpression(0, PostgresDataType.BOOLEAN)); } if (Randomly.getBooleanWithRatherLowProbability()) { - select.setOrderByExpressions(gen.generateOrderBy()); + select.setOrderByClauses(gen.generateOrderBys()); } if (Randomly.getBoolean()) { select.setLimitClause(PostgresConstant.createIntConstant(Randomly.getPositiveOrZeroNonCachedInteger())); diff --git a/src/sqlancer/postgres/oracle/tlp/PostgresTLPHavingOracle.java b/src/sqlancer/postgres/oracle/tlp/PostgresTLPHavingOracle.java index 3a1275299..2f520bebb 100644 --- a/src/sqlancer/postgres/oracle/tlp/PostgresTLPHavingOracle.java +++ b/src/sqlancer/postgres/oracle/tlp/PostgresTLPHavingOracle.java @@ -36,7 +36,7 @@ protected void havingCheck() throws SQLException { boolean orderBy = Randomly.getBoolean(); if (orderBy) { - select.setOrderByExpressions(gen.generateOrderBy()); + select.setOrderByClauses(gen.generateOrderBys()); } select.setHavingClause(predicate); String firstQueryString = PostgresVisitor.asString(select); diff --git a/src/sqlancer/postgres/oracle/tlp/PostgresTLPWhereOracle.java b/src/sqlancer/postgres/oracle/tlp/PostgresTLPWhereOracle.java deleted file mode 100644 index b043f6981..000000000 --- a/src/sqlancer/postgres/oracle/tlp/PostgresTLPWhereOracle.java +++ /dev/null @@ -1,45 +0,0 @@ -package sqlancer.postgres.oracle.tlp; - -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - -import sqlancer.ComparatorHelper; -import sqlancer.Randomly; -import sqlancer.postgres.PostgresGlobalState; -import sqlancer.postgres.PostgresVisitor; - -public class PostgresTLPWhereOracle extends PostgresTLPBase { - - public PostgresTLPWhereOracle(PostgresGlobalState state) { - super(state); - } - - @Override - public void check() throws SQLException { - super.check(); - whereCheck(); - } - - protected void whereCheck() throws SQLException { - if (Randomly.getBooleanWithRatherLowProbability()) { - select.setOrderByExpressions(gen.generateOrderBy()); - } - String originalQueryString = PostgresVisitor.asString(select); - List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); - - select.setOrderByExpressions(Collections.emptyList()); - select.setWhereClause(predicate); - String firstQueryString = PostgresVisitor.asString(select); - select.setWhereClause(negatedPredicate); - String secondQueryString = PostgresVisitor.asString(select); - select.setWhereClause(isNullPredicate); - String thirdQueryString = PostgresVisitor.asString(select); - List combinedString = new ArrayList<>(); - List secondResultSet = ComparatorHelper.getCombinedResultSet(firstQueryString, secondQueryString, - thirdQueryString, combinedString, Randomly.getBoolean(), state, errors); - ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, - state); - } -} diff --git a/src/sqlancer/presto/PrestoBugs.java b/src/sqlancer/presto/PrestoBugs.java new file mode 100644 index 000000000..b0eb3fe57 --- /dev/null +++ b/src/sqlancer/presto/PrestoBugs.java @@ -0,0 +1,20 @@ +package sqlancer.presto; + +public final class PrestoBugs { + + // https://github.com/prestodb/presto/issues/23324 + public static boolean bug23324 = true; + + // https://github.com/prestodb/presto/issues/23613 + public static boolean bug23613 = true; + + // https://github.com/prestodb/presto/issues/27608 + public static boolean bugVerifyError = true; + + // https://github.com/prestodb/presto/issues/27609 + public static boolean bugCompilerFailed = true; + + private PrestoBugs() { + } + +} diff --git a/src/sqlancer/presto/PrestoConstantUtils.java b/src/sqlancer/presto/PrestoConstantUtils.java new file mode 100644 index 000000000..6de744019 --- /dev/null +++ b/src/sqlancer/presto/PrestoConstantUtils.java @@ -0,0 +1,40 @@ +package sqlancer.presto; + +import java.math.BigDecimal; +import java.math.RoundingMode; + +public final class PrestoConstantUtils { + + private PrestoConstantUtils() { + } + + public static String removeNoneAscii(String str) { + return str.replaceAll("[^\\x00-\\x7F]", ""); + } + + public static String removeNonePrintable(String str) { // All Control Char + return str.replaceAll("[\\p{C}]", ""); + } + + public static String removeOthersControlChar(String str) { // Some Control Char + return str.replaceAll("[\\p{Cntrl}\\p{Cc}\\p{Cf}\\p{Co}\\p{Cn}]", ""); + } + + public static String removeAllControlChars(String str) { + return removeOthersControlChar(removeNonePrintable(str)).replaceAll("[\\r\\n\\t]", ""); + } + + public static BigDecimal getDecimal(double val, int scale, int precision) { + int part = precision - scale; + // long part + long lng = (long) val; + // decimal places + double d1 = val - lng; + String xStr = Long.toString(lng); + String substring = xStr.substring(xStr.length() - part); + long newX = substring.isEmpty() ? 0 : Long.parseLong(substring); + double finalD = newX + d1; + return new BigDecimal(finalD).setScale(scale, RoundingMode.CEILING); + } + +} diff --git a/src/sqlancer/presto/PrestoErrors.java b/src/sqlancer/presto/PrestoErrors.java new file mode 100644 index 000000000..dd2931976 --- /dev/null +++ b/src/sqlancer/presto/PrestoErrors.java @@ -0,0 +1,195 @@ +package sqlancer.presto; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.common.query.ExpectedErrors; + +public final class PrestoErrors { + + private PrestoErrors() { + } + + public static List getExpressionErrors() { + ArrayList errors = new ArrayList<>(); + + errors.addAll(getFunctionErrors()); + + // Presto errors + errors.add("cannot be applied to"); + errors.add("LIKE expression must evaluate to a varchar"); + errors.add("JOIN ON clause must evaluate to a boolean"); + // errors.add("Unexpected parameters"); + + // SELECT SUM(count) FROM (SELECT + // CAST((-179769313486231570000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000.0000 + // IS NOT NULL AND + // -179769313486231570000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000.0000) + // AS BIGINT)as count FROM t0) as res + errors.add("Decimal overflow"); + errors.add("long overflow"); + errors.add("multiplication overflow"); + errors.add("addition overflow"); + errors.add("subtraction overflow"); + + // cast + // errors.add("Cannot cast"); + errors.add("Value cannot be cast to"); + errors.add("Cannot cast DECIMAL"); + errors.add("Cannot cast BIGINT"); + errors.add("Cannot cast INTEGER"); + + // TODO: check + errors.add("io.airlift.slice.Slice cannot be cast to java.lang.Number"); + errors.add("class io.airlift.slice.Slice cannot be cast to class java.lang.Number"); + if (PrestoBugs.bug23324) { + errors.add("Cannot cast java.lang.Long to io.airlift.slice.Slice"); + } + errors.add("Cannot cast java.lang.String to java.util.List"); + errors.add("Unexpected subquery expression in logical plan"); + if (PrestoBugs.bugVerifyError) { + errors.add("VerifyError"); + } + if (PrestoBugs.bugCompilerFailed) { + errors.add("Compiler failed"); + errors.add("Error processing class definition"); + } + + // 9223372036854775808 + errors.add("Invalid numeric literal"); + + errors.add("Division by zero"); + errors.add("/ by zero"); + + errors.add("Cannot subtract hour, minutes or seconds from a date"); + errors.add("Cannot add hour, minutes or seconds to a date"); + + errors.add("DECIMAL scale must be in range"); + errors.add("IN value and list items must be the same type"); + errors.add("is not a valid timestamp literal"); + errors.add("Unknown time-zone ID"); + errors.add("GROUP BY position"); + + // ARRAY + errors.add("Unknown type: ARRAY"); + + // SELECT + errors.add("WHERE clause must evaluate to a boolean"); + errors.add("HAVING clause must evaluate to a boolean"); + errors.add("not yet implemented"); + + errors.add("Value expression and result of subquery must be of the same type for quantified comparison"); + errors.add("All IN list values must be the same type"); + errors.add("All CASE results must be the same type"); + errors.add("Mismatched types"); + errors.add("CASE operand type does not match WHEN clause operand type"); + errors.add("Subquery result type must be orderable"); + errors.add("Escape character must be followed by '%', '_' or the escape character itself"); + errors.add("Types are not comparable with NULLIF"); + errors.add("not of the same type"); + + if (PrestoBugs.bug23613) { + errors.add("at index 1"); + } + + return errors; + } + + public static void addExpressionErrors(ExpectedErrors errors) { + errors.addAll(getExpressionErrors()); + } + + private static List getRegexErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("missing ]"); + errors.add("missing )"); + errors.add("invalid escape sequence"); + errors.add("no argument for repetition operator: "); + errors.add("bad repetition operator"); + errors.add("trailing \\"); + errors.add("invalid perl operator"); + errors.add("invalid character class range"); + errors.add("width is not integer"); + + return errors; + } + + private static List getFunctionErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("SUBSTRING cannot handle negative lengths"); + errors.add("is undefined outside [-1,1]"); // ACOS etc + errors.add("invalid type specifier"); // PRINTF + errors.add("argument index out of range"); // PRINTF + errors.add("invalid format string"); // PRINTF + errors.add("number is too big"); // PRINTF + errors.add("Like pattern must not end with escape character!"); // LIKE + errors.add("Could not choose a best candidate function for the function call \"date_part"); // date_part + errors.add("extract specifier"); // date_part + errors.add("not recognized"); // date_part + errors.add("not supported"); // date_part + errors.add("Failed to cast"); + errors.add("Conversion Error"); + errors.add("Could not cast value"); + errors.add("Insufficient padding in RPAD"); // RPAD + errors.add("Could not choose a best candidate function for the function call"); // monthname + errors.add("expected a numeric precision field"); // ROUND + errors.add("with non-constant precision is not supported"); // ROUND + errors.add("Unexpected parameters"); + errors.add("not registered"); + errors.add("Expected: least(E) E:orderable"); + errors.add("Expected: greatest(E) E:orderable"); + errors.add("Expected: max_by(V, K) K:orderable, V, max_by(V, K, bigint) V, K:orderable"); + errors.add("Expected: min_by(V, K) K:orderable, V, min_by(V, K, bigint) V, K:orderable"); + return errors; + } + + // TODO: cover presto error + public static List getInsertErrors() { + ArrayList errors = new ArrayList<>(); + + errors.addAll(getRegexErrors()); + errors.addAll(getExpressionErrors()); + + errors.add("NOT NULL constraint failed"); + errors.add("PRIMARY KEY or UNIQUE constraint violated"); + errors.add("duplicate key"); + errors.add("can't be cast because the value is out of range for the destination type"); + errors.add("Could not convert string"); + errors.add("Unimplemented type for cast"); + errors.add("field value out of range"); + errors.add("CHECK constraint failed"); + errors.add("Cannot explicitly insert values into rowid column"); // TODO: don't insert into rowid + errors.add(" Column with name rowid does not exist!"); // currently, there doesn't seem to way to determine if + // the table has a primary key + errors.add("Could not cast value"); + errors.add("create unique index, table contains duplicate data"); + errors.add("Failed to cast"); + + errors.add("Values rows have mismatched types"); + errors.add("Mismatch at column"); + errors.add("This connector does not support updates or deletes"); + errors.add("Values rows have mismatched types"); + errors.add("Invalid numeric literal"); + + return errors; + } + + public static void addInsertErrors(ExpectedErrors errors) { + errors.addAll(getInsertErrors()); + } + + public static List getGroupByErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("must be an aggregate expression or appear in GROUP BY clause"); + + return errors; + } + + public static void addGroupByErrors(ExpectedErrors errors) { + errors.addAll(getGroupByErrors()); + } + +} diff --git a/src/sqlancer/presto/PrestoGlobalState.java b/src/sqlancer/presto/PrestoGlobalState.java new file mode 100644 index 000000000..eb053712e --- /dev/null +++ b/src/sqlancer/presto/PrestoGlobalState.java @@ -0,0 +1,13 @@ +package sqlancer.presto; + +import java.sql.SQLException; + +import sqlancer.SQLGlobalState; + +public class PrestoGlobalState extends SQLGlobalState { + + @Override + protected PrestoSchema readSchema() throws SQLException { + return PrestoSchema.fromConnection(getConnection(), getDatabaseName()); + } +} diff --git a/src/sqlancer/presto/PrestoOptions.java b/src/sqlancer/presto/PrestoOptions.java new file mode 100644 index 000000000..f0c557ca5 --- /dev/null +++ b/src/sqlancer/presto/PrestoOptions.java @@ -0,0 +1,102 @@ +package sqlancer.presto; + +import java.util.List; + +import com.beust.jcommander.Parameter; +import com.beust.jcommander.Parameters; + +import sqlancer.DBMSSpecificOptions; + +@Parameters(commandDescription = "Presto") +public class PrestoOptions implements DBMSSpecificOptions { + + public static final String DEFAULT_HOST = "localhost"; + public static final int DEFAULT_PORT = 8080; + + @Parameter(names = "--test-collate", arity = 1) + public boolean testCollate = true; + + @Parameter(names = "--test-check", description = "Allow generating CHECK constraints in tables", arity = 1) + public boolean testCheckConstraints = true; + + @Parameter(names = "--test-default-values", description = "Allow generating DEFAULT values in tables", arity = 1) + public boolean testDefaultValues = true; + + @Parameter(names = "--test-not-null", description = "Allow generating NOT NULL constraints in tables", arity = 1) + public boolean testNotNullConstraints = true; + + @Parameter(names = "--test-functions", description = "Allow generating functions in expressions", arity = 1) + public boolean testFunctions = true; + + @Parameter(names = "--test-casts", description = "Allow generating casts in expressions", arity = 1) + public boolean testCasts = true; + + @Parameter(names = "--test-between", description = "Allow generating the BETWEEN operator in expressions (FALSE by default : Presto null handling in BETWEEN operator : https://prestodb.io/docs/current/functions/comparison.html )", arity = 1) + public boolean testBetween; + + @Parameter(names = "--test-in", description = "Allow generating the IN operator in expressions", arity = 1) + public boolean testIn = true; + + @Parameter(names = "--test-case", description = "Allow generating the CASE operator in expressions", arity = 1) + public boolean testCase = true; + + @Parameter(names = "--test-binary-logicals", description = "Allow generating AND and OR in expressions", arity = 1) + public boolean testBinaryLogicals = true; + + @Parameter(names = "--test-int-constants", description = "Allow generating INTEGER constants", arity = 1) + public boolean testIntConstants = true; + + @Parameter(names = "--test-varchar-constants", description = "Allow generating VARCHAR constants", arity = 1) + public boolean testStringConstants = true; + + @Parameter(names = "--test-time-constants", description = "Allow generating DATE constants", arity = 1) + public boolean testDateConstants = true; + + @Parameter(names = "--test-date-constants", description = "Allow generating DATE constants", arity = 1) + public boolean testTimeConstants = true; + + @Parameter(names = "--test-timestamp-constants", description = "Allow generating TIMESTAMP constants", arity = 1) + public boolean testTimestampConstants = true; + + @Parameter(names = "--test-float-constants", description = "Allow generating floating-point constants", arity = 1) + public boolean testFloatConstants = true; + + @Parameter(names = "--test-boolean-constants", description = "Allow generating boolean constants", arity = 1) + public boolean testBooleanConstants = true; + + @Parameter(names = "--test-binary-comparisons", description = "Allow generating binary comparison operators (e.g., >= or LIKE)", arity = 1) + public boolean testBinaryComparisons = true; + + @Parameter(names = "--test-indexes", description = "Allow explicit (i.e. CREATE INDEX) and implicit (i.e., UNIQUE and PRIMARY KEY) indexes", arity = 1) + public boolean testIndexes = true; + + @Parameter(names = "--test-rowid", description = "Test tables' rowid columns", arity = 1) + public boolean testRowid = true; + + @Parameter(names = "--max-num-views", description = "The maximum number of views that can be generated for a database", arity = 1) + public int maxNumViews = 1; + + @Parameter(names = "--max-num-deletes", description = "The maximum number of DELETE statements that are issued for a database", arity = 1) + public int maxNumDeletes = 1; + + @Parameter(names = "--max-num-updates", description = "The maximum number of UPDATE statements that are issued for a database", arity = 1) + public int maxNumUpdates = 5; + + @Parameter(names = "--oracle") + public List oracles = List.of(PrestoOracleFactory.NOREC); + + @Parameter(names = "--catalog") + public String catalog = "memory"; + + @Parameter(names = "--schema") + public String schema = "test"; + + @Parameter(names = "--typed-generator", description = "the expression generator type - typed and untyped ") + public boolean typedGenerator = true; + + @Override + public List getTestOracleFactory() { + return oracles; + } + +} diff --git a/src/sqlancer/presto/PrestoOracleFactory.java b/src/sqlancer/presto/PrestoOracleFactory.java new file mode 100644 index 000000000..3076357f2 --- /dev/null +++ b/src/sqlancer/presto/PrestoOracleFactory.java @@ -0,0 +1,73 @@ +package sqlancer.presto; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.OracleFactory; +import sqlancer.common.oracle.CompositeTestOracle; +import sqlancer.common.oracle.NoRECOracle; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.presto.gen.PrestoTypedExpressionGenerator; +import sqlancer.presto.test.PrestoQueryPartitioningAggregateTester; +import sqlancer.presto.test.PrestoQueryPartitioningDistinctTester; +import sqlancer.presto.test.PrestoQueryPartitioningGroupByTester; +import sqlancer.presto.test.PrestoQueryPartitioningHavingTester; +import sqlancer.presto.test.PrestoQueryPartitioningWhereTester; + +public enum PrestoOracleFactory implements OracleFactory { + NOREC { + @Override + public TestOracle create(PrestoGlobalState globalState) { + PrestoTypedExpressionGenerator gen = new PrestoTypedExpressionGenerator(globalState); + ExpectedErrors errors = ExpectedErrors.newErrors().with(PrestoErrors.getExpressionErrors()) + .with("canceling statement due to statement timeout").build(); + return new NoRECOracle<>(globalState, gen, errors); + } + + }, + HAVING { + @Override + public TestOracle create(PrestoGlobalState globalState) { + return new PrestoQueryPartitioningHavingTester(globalState); + } + }, + WHERE { + @Override + public TestOracle create(PrestoGlobalState globalState) { + return new PrestoQueryPartitioningWhereTester(globalState); + } + }, + GROUP_BY { + @Override + public TestOracle create(PrestoGlobalState globalState) { + return new PrestoQueryPartitioningGroupByTester(globalState); + } + }, + AGGREGATE { + @Override + public TestOracle create(PrestoGlobalState globalState) { + return new PrestoQueryPartitioningAggregateTester(globalState); + } + + }, + DISTINCT { + @Override + public TestOracle create(PrestoGlobalState globalState) { + return new PrestoQueryPartitioningDistinctTester(globalState); + } + }, + QUERY_PARTITIONING { + @Override + public TestOracle create(PrestoGlobalState globalState) throws Exception { + List> oracles = new ArrayList<>(); + oracles.add(WHERE.create(globalState)); + oracles.add(HAVING.create(globalState)); + oracles.add(AGGREGATE.create(globalState)); + oracles.add(DISTINCT.create(globalState)); + oracles.add(GROUP_BY.create(globalState)); + return new CompositeTestOracle<>(oracles, globalState); + } + } + +} diff --git a/src/sqlancer/presto/PrestoProvider.java b/src/sqlancer/presto/PrestoProvider.java new file mode 100644 index 000000000..be3ef4325 --- /dev/null +++ b/src/sqlancer/presto/PrestoProvider.java @@ -0,0 +1,195 @@ +package sqlancer.presto; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import com.google.auto.service.AutoService; + +import sqlancer.AbstractAction; +import sqlancer.DatabaseProvider; +import sqlancer.IgnoreMeException; +import sqlancer.MainOptions; +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.SQLProviderAdapter; +import sqlancer.StatementExecutor; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.query.SQLQueryProvider; +import sqlancer.presto.gen.PrestoInsertGenerator; +import sqlancer.presto.gen.PrestoTableGenerator; + +@AutoService(DatabaseProvider.class) +public class PrestoProvider extends SQLProviderAdapter { + + public PrestoProvider() { + super(PrestoGlobalState.class, PrestoOptions.class); + } + + // TODO : check actions based on connector + // returns number of actions + private static int mapActions(PrestoGlobalState globalState, Action a) { + Randomly r = globalState.getRandomly(); + if (Objects.requireNonNull(a) == Action.INSERT) { + return r.getInteger(0, globalState.getOptions().getMaxNumberInserts()); + // case UPDATE: + // return r.getInteger(0, globalState.getDbmsSpecificOptions().maxNumUpdates + 1); + // case EXPLAIN: + // return r.getInteger(0, 2); + // case DELETE: + // return r.getInteger(0, globalState.getDbmsSpecificOptions().maxNumDeletes + 1); + // case CREATE_VIEW: + // return r.getInteger(0, globalState.getDbmsSpecificOptions().maxNumViews + 1); + } + throw new AssertionError(a); + } + + @Override + public void generateDatabase(PrestoGlobalState globalState) throws Exception { + for (int i = 0; i < Randomly.fromOptions(1, 2); i++) { + boolean success; + do { + SQLQueryAdapter qt = new PrestoTableGenerator().getQuery(globalState); + success = globalState.executeStatement(qt); + } while (!success); + } + if (globalState.getSchema().getDatabaseTables().isEmpty()) { + throw new IgnoreMeException(); // TODO + } + StatementExecutor se = new StatementExecutor<>(globalState, Action.values(), + PrestoProvider::mapActions, (q) -> { + if (globalState.getSchema().getDatabaseTables().isEmpty()) { + throw new IgnoreMeException(); + } + }); + se.executeStatements(); + } + + @Override + public SQLConnection createDatabase(PrestoGlobalState globalState) throws SQLException { + String username = globalState.getOptions().getUserName(); + String password = globalState.getOptions().getPassword(); + boolean useSSl = true; + if (globalState.getOptions().isDefaultUsername() && globalState.getOptions().isDefaultPassword()) { + username = "presto"; + password = null; + useSSl = false; + } + String host = globalState.getOptions().getHost(); + int port = globalState.getOptions().getPort(); + if (host == null) { + host = PrestoOptions.DEFAULT_HOST; + } + if (port == MainOptions.NO_SET_PORT) { + port = PrestoOptions.DEFAULT_PORT; + } + String catalogName = globalState.getDbmsSpecificOptions().catalog; + String databaseName = globalState.getDatabaseName(); + String url = String.format("jdbc:presto://%s:%d/%s?SSL=%b", host, port, catalogName, useSSl); + Connection con = DriverManager.getConnection(url, username, password); + List schemaNames = getSchemaNames(con, catalogName, databaseName); + dropExistingTables(con, catalogName, databaseName, schemaNames); + dropSchema(globalState, con, catalogName, databaseName); + createSchema(globalState, con, catalogName, databaseName); + useSchema(globalState, con, catalogName, databaseName); + return new SQLConnection(con); + + } + + private static void useSchema(PrestoGlobalState globalState, Connection con, String catalogName, + String databaseName) throws SQLException { + globalState.getState().logStatement("USE " + catalogName + "." + databaseName); + try (Statement s = con.createStatement()) { + s.execute("USE " + catalogName + "." + databaseName); + } + } + + private static void createSchema(PrestoGlobalState globalState, Connection con, String catalogName, + String databaseName) throws SQLException { + globalState.getState().logStatement("CREATE SCHEMA IF NOT EXISTS " + catalogName + "." + databaseName); + try (Statement s = con.createStatement()) { + s.execute("CREATE SCHEMA IF NOT EXISTS " + catalogName + "." + databaseName); + } + } + + private static void dropSchema(PrestoGlobalState globalState, Connection con, String catalogName, + String databaseName) throws SQLException { + globalState.getState().logStatement("DROP SCHEMA IF EXISTS " + catalogName + "." + databaseName); + try (Statement s = con.createStatement()) { + s.execute("DROP SCHEMA IF EXISTS " + catalogName + "." + databaseName); + } + } + + private static List getSchemaNames(Connection con, String catalogName, String databaseName) + throws SQLException { + List schemaNames = new ArrayList<>(); + final String showSchemasSql = "SHOW SCHEMAS FROM " + catalogName + " LIKE '" + databaseName + "'"; + try (Statement s = con.createStatement()) { + try (ResultSet rs = s.executeQuery(showSchemasSql)) { + while (rs.next()) { + schemaNames.add(rs.getString("Schema")); + } + } + } + return schemaNames; + } + + private static void dropExistingTables(Connection con, String catalogName, String databaseName, + List schemaNames) throws SQLException { + if (!schemaNames.isEmpty()) { + List tableNames = new ArrayList<>(); + try (Statement s = con.createStatement()) { + try (ResultSet rs = s.executeQuery("SHOW TABLES FROM " + catalogName + "." + databaseName)) { + while (rs.next()) { + tableNames.add(rs.getString("Table")); + } + } + } + try (Statement s = con.createStatement()) { + for (String tableName : tableNames) { + s.execute("DROP TABLE IF EXISTS " + catalogName + "." + databaseName + "." + tableName); + } + } + } + } + + @Override + public String getDBMSName() { + return "presto"; + } + + public enum Action implements AbstractAction { + // SHOW_TABLES((g) -> new SQLQueryAdapter("SHOW TABLES", new ExpectedErrors(), false, false)), // + INSERT(PrestoInsertGenerator::getQuery); + // TODO : check actions based on connector + // DELETE(PrestoDeleteGenerator::generate), // + // UPDATE(PrestoUpdateGenerator::getQuery), // + // CREATE_VIEW(PrestoViewGenerator::generate), // + // EXPLAIN((g) -> { + // ExpectedErrors errors = new ExpectedErrors(); + // PrestoErrors.addExpressionErrors(errors); + // PrestoErrors.addGroupByErrors(errors); + // return new SQLQueryAdapter( + // "EXPLAIN " + PrestoToStringVisitor + // .asString(PrestoRandomQuerySynthesizer.generateSelect(g, Randomly.smallNumber() + 1)), + // errors); + // }); + + private final SQLQueryProvider sqlQueryProvider; + + Action(SQLQueryProvider sqlQueryProvider) { + this.sqlQueryProvider = sqlQueryProvider; + } + + @Override + public SQLQueryAdapter getQuery(PrestoGlobalState state) throws Exception { + return sqlQueryProvider.getQuery(state); + } + } + +} diff --git a/src/sqlancer/presto/PrestoSchema.java b/src/sqlancer/presto/PrestoSchema.java new file mode 100644 index 000000000..2e668969d --- /dev/null +++ b/src/sqlancer/presto/PrestoSchema.java @@ -0,0 +1,487 @@ +package sqlancer.presto; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.common.schema.AbstractRelationalTable; +import sqlancer.common.schema.AbstractSchema; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.common.schema.AbstractTables; +import sqlancer.common.schema.TableIndex; + +public class PrestoSchema extends AbstractSchema { + + public PrestoSchema(List databaseTables) { + super(databaseTables); + } + + public static PrestoSchema fromConnection(SQLConnection con, String databaseName) throws SQLException { + List databaseTables = new ArrayList<>(); + List tableNames = getTableNames(con); + for (String tableName : tableNames) { + List databaseColumns = getTableColumns(con, databaseName, tableName); + boolean isView = matchesViewName(tableName); + PrestoTable t = new PrestoTable(tableName, databaseColumns, isView); + for (PrestoColumn c : databaseColumns) { + c.setTable(t); + } + databaseTables.add(t); + } + return new PrestoSchema(databaseTables); + } + + private static List getTableNames(SQLConnection con) throws SQLException { + List tableNames = new ArrayList<>(); + try (Statement s = con.createStatement()) { + // TODO: UPDATE + // SHOW TABLES [ FROM schema ] [ LIKE pattern [ ESCAPE 'escape_character' ] ] + try (ResultSet rs = s.executeQuery("SHOW TABLES")) { + while (rs.next()) { + tableNames.add(rs.getString("Table")); + } + } + } + return tableNames; + } + + private static List getTableColumns(SQLConnection con, String databaseName, String tableName) + throws SQLException { + List columns = new ArrayList<>(); + try (Statement s = con.createStatement()) { + try (ResultSet rs = s.executeQuery(String.format("select " + " table_catalog " + " , table_schema " + + " , table_name " + " , column_name " + " , is_nullable " + " , data_type " + + " from information_schema.columns " + " where table_schema = '%s' and table_name = '%s'", + databaseName, tableName))) { + while (rs.next()) { + String columnName = rs.getString("column_name"); + String dataType = rs.getString("data_type"); + boolean isNullable = rs.getString("is_nullable").contentEquals("YES"); + PrestoColumn c = new PrestoColumn(columnName, getColumnType(dataType), false, isNullable); + columns.add(c); + } + } + } + + return columns; + } + + private static PrestoCompositeDataType getColumnType(String typeString) { + int bracesStart = typeString.indexOf('('); + String type; + int size = 0; + int precision = 0; + if (bracesStart != -1) { + type = typeString.substring(0, bracesStart); + } else { + type = typeString; + } + type = type.toUpperCase(); + + PrestoDataType primitiveType; + switch (type) { + case "INTEGER": + primitiveType = PrestoDataType.INT; + size = 4; + break; + case "SMALLINT": + primitiveType = PrestoDataType.INT; + size = 2; + break; + case "BIGINT": + primitiveType = PrestoDataType.INT; + size = 8; + break; + case "TINYINT": + primitiveType = PrestoDataType.INT; + size = 1; + break; + case "VARCHAR": + primitiveType = PrestoDataType.VARCHAR; + break; + case "VARBINARY": + primitiveType = PrestoDataType.VARBINARY; + break; + case "CHAR": + primitiveType = PrestoDataType.CHAR; + break; + case "FLOAT": + case "REAL": + primitiveType = PrestoDataType.FLOAT; + size = 4; + break; + case "DOUBLE": + primitiveType = PrestoDataType.FLOAT; + size = 8; + break; + case "DECIMAL": + primitiveType = PrestoDataType.DECIMAL; + break; + case "BOOLEAN": + primitiveType = PrestoDataType.BOOLEAN; + break; + case "DATE": + primitiveType = PrestoDataType.DATE; + break; + case "TIME": + primitiveType = PrestoDataType.TIME; + break; + case "TIME WITH TIME ZONE": + primitiveType = PrestoDataType.TIME_WITH_TIME_ZONE; + break; + case "TIMESTAMP": + primitiveType = PrestoDataType.TIMESTAMP; + break; + case "TIMESTAMP WITH TIME ZONE": + primitiveType = PrestoDataType.TIMESTAMP_WITH_TIME_ZONE; + break; + case "INTERVAL DAY TO SECOND": + primitiveType = PrestoDataType.INTERVAL_DAY_TO_SECOND; + break; + case "INTERVAL YEAR TO MONTH": + primitiveType = PrestoDataType.INTERVAL_YEAR_TO_MONTH; + break; + case "JSON": + primitiveType = PrestoDataType.JSON; + break; + case "ARRAY": + int bracesEnd = typeString.length() - 1; + primitiveType = PrestoDataType.ARRAY; + PrestoCompositeDataType elementType = getColumnType(typeString.substring(bracesStart + 1, bracesEnd)); + return new PrestoCompositeDataType(primitiveType, elementType); + case "NULL": + primitiveType = PrestoDataType.NULL; + break; + default: + throw new AssertionError(typeString); + } + return new PrestoCompositeDataType(primitiveType, size, precision); + } + + public PrestoTables getRandomTableNonEmptyTables() { + return new PrestoTables(Randomly.nonEmptySubset(getDatabaseTables())); + } + + public enum PrestoDataType { + BOOLEAN, INT, FLOAT, DECIMAL, VARCHAR, CHAR, VARBINARY, JSON, DATE, TIME, TIMESTAMP, TIME_WITH_TIME_ZONE, + TIMESTAMP_WITH_TIME_ZONE, INTERVAL_YEAR_TO_MONTH, INTERVAL_DAY_TO_SECOND, ARRAY, + // MAP, + // ROW, + // IPADDRESS, + // UID, + // IPPREFIX, + // HyperLogLog, + // P4HyperLogLog, + // KHyperLogLog, + // QDigest, + // TDigest, + NULL; + + public static PrestoDataType getRandomWithoutNull() { + PrestoDataType dt; + do { + dt = Randomly.fromOptions(values()); + } while (dt == PrestoDataType.NULL); + return dt; + } + + public static List getNumericTypes() { + return Arrays.asList(INT, FLOAT, DECIMAL, DATE, TIME, TIMESTAMP, TIME_WITH_TIME_ZONE, + TIMESTAMP_WITH_TIME_ZONE); + } + + public static List getComparableTypes() { + return Arrays.asList(BOOLEAN, INT, FLOAT, DECIMAL, VARCHAR, CHAR, VARBINARY, JSON, DATE, TIME, TIMESTAMP, + TIME_WITH_TIME_ZONE, TIMESTAMP_WITH_TIME_ZONE, INTERVAL_YEAR_TO_MONTH, INTERVAL_DAY_TO_SECOND); + } + + public static List getOrderableTypes() { + return Arrays.asList(BOOLEAN, INT, FLOAT, DECIMAL, VARCHAR, CHAR, VARBINARY, + // JSON, + DATE, TIME, TIMESTAMP, TIME_WITH_TIME_ZONE, TIMESTAMP_WITH_TIME_ZONE, INTERVAL_YEAR_TO_MONTH, + INTERVAL_DAY_TO_SECOND, ARRAY); + } + + public static List getNumberTypes() { + return Arrays.asList(INT, FLOAT, DECIMAL); + } + + public static List getTemporalTypes() { + return Arrays.asList(DATE, TIME, TIMESTAMP, TIME_WITH_TIME_ZONE, TIMESTAMP_WITH_TIME_ZONE); + } + + public static List getIntervalTypes() { + return Arrays.asList(INTERVAL_YEAR_TO_MONTH, INTERVAL_DAY_TO_SECOND); + } + + public static List getTextTypes() { + return Arrays.asList(VARCHAR, CHAR, VARBINARY, JSON); + } + + public boolean isNumeric() { + switch (this) { + case INT: + case FLOAT: + case DECIMAL: + return true; + default: + return false; + } + } + + public boolean isOrderable() { + return getOrderableTypes().contains(this); + } + + public PrestoCompositeDataType get() { + return PrestoCompositeDataType.fromDataType(this); + } + } + + public static class PrestoCompositeDataType { + + private final PrestoDataType dataType; + + private final int size; + + private final int scale; + + private final PrestoCompositeDataType elementType; + + public PrestoCompositeDataType(PrestoDataType dataType, int dataSize, int dataScale) { + this.dataType = dataType; + this.size = dataSize; + this.scale = dataScale; + this.elementType = null; + } + + public PrestoCompositeDataType(PrestoDataType dataType, PrestoCompositeDataType elementType) { + if (dataType != PrestoDataType.ARRAY) { + throw new IllegalArgumentException(); + } + this.dataType = dataType; + this.size = -1; + this.scale = -1; + this.elementType = elementType; + } + + public static PrestoCompositeDataType getRandomWithoutNull() { + PrestoDataType type = PrestoDataType.getRandomWithoutNull(); + int size; + int scale = -1; + switch (type) { + case INT: + size = Randomly.fromOptions(1, 2, 4, 8); + break; + case FLOAT: + size = Randomly.fromOptions(4, 8); + break; + case DECIMAL: + size = Math.toIntExact(8); + scale = Math.toIntExact(4); + break; + case VARBINARY: + case JSON: + case VARCHAR: + case CHAR: + size = Math.toIntExact(Randomly.getNotCachedInteger(10, 250)); + break; + case ARRAY: + return new PrestoCompositeDataType(type, PrestoCompositeDataType.getRandomWithoutNull()); + case BOOLEAN: + case DATE: + case TIME: + case TIME_WITH_TIME_ZONE: + case TIMESTAMP: + case TIMESTAMP_WITH_TIME_ZONE: + case INTERVAL_DAY_TO_SECOND: + case INTERVAL_YEAR_TO_MONTH: + size = 0; + break; + default: + throw new AssertionError(type); + } + + return new PrestoCompositeDataType(type, size, scale); + } + + public static PrestoCompositeDataType fromDataType(PrestoDataType type) { + int size; + int scale = -1; + switch (type) { + case INT: + size = Randomly.fromOptions(1, 2, 4, 8); + break; + case FLOAT: + size = Randomly.fromOptions(4, 8); + break; + case DECIMAL: + size = Math.toIntExact(8); + scale = Math.toIntExact(4); + break; + case JSON: + case VARCHAR: + case CHAR: + size = Math.toIntExact(Randomly.getNotCachedInteger(10, 250)); + break; + case ARRAY: + return new PrestoCompositeDataType(type, PrestoCompositeDataType.getRandomWithoutNull()); + case BOOLEAN: + case VARBINARY: + case DATE: + case TIME: + case TIMESTAMP: + case TIMESTAMP_WITH_TIME_ZONE: + case TIME_WITH_TIME_ZONE: + case INTERVAL_DAY_TO_SECOND: + case INTERVAL_YEAR_TO_MONTH: + size = 0; + break; + default: + throw new AssertionError(type); + } + + return new PrestoCompositeDataType(type, size, scale); + } + + public PrestoDataType getPrimitiveDataType() { + return dataType; + } + + public int getSize() { + if (size == -1) { + throw new AssertionError(this); + } + return size; + } + + public int getScale() { + if (scale == -1) { + throw new AssertionError(this); + } + return scale; + } + + @Override + public String toString() { + switch (getPrimitiveDataType()) { + case INT: + switch (size) { + case 8: + return "BIGINT"; + case 4: + return "INTEGER"; + case 2: + return "SMALLINT"; + case 1: + return "TINYINT"; + default: + throw new AssertionError(size); + } + case VARBINARY: + return "VARBINARY"; + case JSON: + return "JSON"; + case VARCHAR: + return "VARCHAR" + "(" + size + ")"; + case CHAR: + return "CHAR" + "(" + size + ")"; + case FLOAT: + switch (size) { + case 4: + return "REAL"; + case 8: + return "DOUBLE"; + default: + throw new AssertionError(size); + } + case DECIMAL: + return "DECIMAL" + "(" + size + ", " + scale + ")"; + case BOOLEAN: + return "BOOLEAN"; + case TIMESTAMP_WITH_TIME_ZONE: + return "TIMESTAMP WITH TIME ZONE"; + case TIMESTAMP: + return "TIMESTAMP"; + case INTERVAL_YEAR_TO_MONTH: + return "INTERVAL YEAR TO MONTH"; + case INTERVAL_DAY_TO_SECOND: + return "INTERVAL DAY TO SECOND"; + case DATE: + return "DATE"; + case TIME: + return "TIME"; + case TIME_WITH_TIME_ZONE: + return "TIME WITH TIME ZONE"; + case ARRAY: + return "ARRAY(" + elementType + ")"; + case NULL: + return "NULL"; + default: + throw new AssertionError(getPrimitiveDataType()); + } + } + + public PrestoCompositeDataType getElementType() { + return elementType; + } + + public boolean isOrderable() { + if (dataType == PrestoDataType.ARRAY) { + assert elementType != null; + return elementType.isOrderable(); + } + return dataType.isOrderable(); + } + + } + + public static class PrestoColumn extends AbstractTableColumn { + + private final boolean isPrimaryKey; + private final boolean isNullable; + + public PrestoColumn(String name, PrestoCompositeDataType columnType, boolean isPrimaryKey, boolean isNullable) { + super(name, null, columnType); + this.isPrimaryKey = isPrimaryKey; + this.isNullable = isNullable; + } + + @Override + public boolean isPrimaryKey() { + return isPrimaryKey; + } + + public boolean isNullable() { + return isNullable; + } + + public boolean isOrderable() { + return getType().getPrimitiveDataType().isOrderable(); + } + + } + + public static class PrestoTables extends AbstractTables { + + public PrestoTables(List tables) { + super(tables); + } + + } + + public static class PrestoTable extends AbstractRelationalTable { + + public PrestoTable(String tableName, List columns, boolean isView) { + super(tableName, columns, Collections.emptyList(), isView); + } + + } + +} diff --git a/src/sqlancer/presto/PrestoToStringVisitor.java b/src/sqlancer/presto/PrestoToStringVisitor.java new file mode 100644 index 000000000..11c6d5fa6 --- /dev/null +++ b/src/sqlancer/presto/PrestoToStringVisitor.java @@ -0,0 +1,148 @@ +package sqlancer.presto; + +import sqlancer.common.ast.newast.NewToStringVisitor; +import sqlancer.presto.ast.PrestoAtTimeZoneOperator; +import sqlancer.presto.ast.PrestoCastFunction; +import sqlancer.presto.ast.PrestoConstant; +import sqlancer.presto.ast.PrestoExpression; +import sqlancer.presto.ast.PrestoFunctionWithoutParenthesis; +import sqlancer.presto.ast.PrestoJoin; +import sqlancer.presto.ast.PrestoMultiValuedComparison; +import sqlancer.presto.ast.PrestoQuantifiedComparison; +import sqlancer.presto.ast.PrestoSelect; + +public class PrestoToStringVisitor extends NewToStringVisitor { + + public static String asString(PrestoExpression expr) { + PrestoToStringVisitor visitor = new PrestoToStringVisitor(); + visitor.visit(expr); + return visitor.get(); + } + + @Override + public void visitSpecific(PrestoExpression expr) { + if (expr instanceof PrestoConstant) { + visit((PrestoConstant) expr); + } else if (expr instanceof PrestoSelect) { + visit((PrestoSelect) expr); + } else if (expr instanceof PrestoJoin) { + visit((PrestoJoin) expr); + } else if (expr instanceof PrestoCastFunction) { + visit((PrestoCastFunction) expr); + } else if (expr instanceof PrestoFunctionWithoutParenthesis) { + visit((PrestoFunctionWithoutParenthesis) expr); + } else if (expr instanceof PrestoAtTimeZoneOperator) { + visit((PrestoAtTimeZoneOperator) expr); + } else if (expr instanceof PrestoMultiValuedComparison) { + visit((PrestoMultiValuedComparison) expr); + } else if (expr instanceof PrestoQuantifiedComparison) { + visit((PrestoQuantifiedComparison) expr); + } else { + throw new AssertionError(expr.getClass()); + } + } + + private void visit(PrestoJoin join) { + visit((PrestoExpression) join.getLeftTable()); + sb.append(" "); + sb.append(join.getJoinType()); + sb.append(" "); + if (join.getOuterType() != null) { + sb.append(join.getOuterType()); + } + sb.append(" JOIN "); + visit((PrestoExpression) join.getRightTable()); + if (join.getOnCondition() != null) { + sb.append(" ON "); + visit(join.getOnCondition()); + } + } + + private void visit(PrestoConstant constant) { + sb.append(constant.toString()); + } + + private void visit(PrestoAtTimeZoneOperator timeZoneOperator) { + visit(timeZoneOperator.getExpr()); + sb.append(" AT TIME ZONE "); + sb.append(timeZoneOperator.getTimeZone()); + } + + private void visit(PrestoFunctionWithoutParenthesis prestoFunctionWithoutParenthesis) { + sb.append(prestoFunctionWithoutParenthesis.getExpr()); + } + + private void visit(PrestoSelect select) { + sb.append("SELECT "); + if (select.isDistinct()) { + sb.append("DISTINCT "); + } + visit(select.getFetchColumns()); + sb.append(" FROM "); + visit(select.getFromList()); + if (!select.getFromList().isEmpty() && !select.getJoinList().isEmpty()) { + sb.append(", "); + } + if (!select.getJoinList().isEmpty()) { + visit(select.getJoinList()); + } + if (select.getWhereClause() != null) { + sb.append(" WHERE "); + visit(select.getWhereClause()); + } + if (!select.getGroupByExpressions().isEmpty()) { + sb.append(" GROUP BY "); + visit(select.getGroupByExpressions()); + } + if (select.getHavingClause() != null) { + sb.append(" HAVING "); + visit(select.getHavingClause()); + } + if (!select.getOrderByClauses().isEmpty()) { + sb.append(" ORDER BY "); + visit(select.getOrderByClauses()); + } + if (select.getLimitClause() != null) { + sb.append(" LIMIT "); + visit(select.getLimitClause()); + } + if (select.getOffsetClause() != null) { + sb.append(" OFFSET "); + visit(select.getOffsetClause()); + } + } + + public void visit(PrestoCastFunction cast) { + sb.append("CAST(("); + visit(cast.getExpr()); + sb.append(") AS "); + sb.append(cast.getType().toString()); + sb.append(")"); + } + + public void visit(PrestoMultiValuedComparison comp) { + sb.append("("); + visit(comp.getLeft()); + sb.append(" "); + sb.append(comp.getOp().getStringRepresentation()); + sb.append(" "); + sb.append(comp.getType()); + sb.append(" (VALUES "); + visit(comp.getRight()); + sb.append(")"); + sb.append(")"); + } + + public void visit(PrestoQuantifiedComparison comp) { + sb.append("("); + visit(comp.getLeft()); + sb.append(" "); + sb.append(comp.getOp().getStringRepresentation()); + sb.append(" "); + sb.append(comp.getType()); + sb.append(" ( "); + visit(comp.getRight()); + sb.append(" ) "); + sb.append(")"); + } +} diff --git a/src/sqlancer/presto/ast/PrestoAggregateFunction.java b/src/sqlancer/presto/ast/PrestoAggregateFunction.java new file mode 100644 index 000000000..25840933a --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoAggregateFunction.java @@ -0,0 +1,713 @@ +package sqlancer.presto.ast; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import sqlancer.Randomly; +import sqlancer.presto.PrestoSchema; +import sqlancer.presto.PrestoSchema.PrestoCompositeDataType; +import sqlancer.presto.PrestoSchema.PrestoDataType; +import sqlancer.presto.gen.PrestoTypedExpressionGenerator; + +public enum PrestoAggregateFunction implements PrestoFunction { + + // General Aggregate Functions + + // arbitrary(x) → [same as input] + // Returns an arbitrary non-null value of x, if one exists. + ARBITRARY("arbitrary", null) { + @Override + public boolean isCompatibleWithReturnType(PrestoCompositeDataType returnType) { + return true; + } + + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + return new PrestoDataType[] { returnType.getPrimitiveDataType() }; + } + + @Override + public PrestoDataType getReturnType() { + return PrestoDataType.getRandomWithoutNull(); + } + + @Override + public List getArgumentsForReturnType(PrestoTypedExpressionGenerator gen, int depth, + PrestoCompositeDataType returnType, boolean orderable) { + PrestoCompositeDataType returnTypeLocal = Objects.requireNonNullElseGet(returnType, + () -> PrestoCompositeDataType.fromDataType(getReturnType())); + return super.getArgumentsForReturnType(gen, depth, returnTypeLocal, orderable); + } + + }, + + // TODO: + // + // array_agg(x) → array<[same as input]># + // Returns an array created from the input x elements. + + // avg(x) → double + // Returns the average (arithmetic mean) of all input values. + AVG("avg", PrestoDataType.FLOAT) { + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + return new PrestoDataType[] { + Randomly.fromOptions(PrestoDataType.INT, PrestoDataType.FLOAT, PrestoDataType.DECIMAL) }; + } + }, + // avg(time interval type) → time interval type# + // Returns the average interval length of all input values. + AVG_INTERVAL_YM("avg", PrestoDataType.INTERVAL_YEAR_TO_MONTH, PrestoDataType.INTERVAL_YEAR_TO_MONTH), + AVG_INTERVAL_DS("avg", PrestoDataType.INTERVAL_DAY_TO_SECOND, PrestoDataType.INTERVAL_DAY_TO_SECOND), + + // bool_and(boolean) → boolean# + // Returns TRUE if every input value is TRUE, otherwise FALSE. + BOOL_AND("bool_and", PrestoDataType.BOOLEAN, PrestoDataType.BOOLEAN), + // bool_or(boolean) → boolean# + // Returns TRUE if any input value is TRUE, otherwise FALSE. + BOOL_OR("bool_or", PrestoDataType.BOOLEAN, PrestoDataType.BOOLEAN), + // checksum(x) → varbinary# + // Returns an order-insensitive checksum of the given values. + CHECKSUM("checksum", PrestoDataType.VARBINARY) { + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + return new PrestoDataType[] { Randomly.fromList(PrestoDataType.getComparableTypes()) }; + } + }, + + // count(*) → bigint# + // Returns the number of input rows. + COUNT_ALL("count(*)", PrestoDataType.INT), + // count(x) → bigint# + // Returns the number of non-null input values. + COUNT_NOARGS("count", PrestoDataType.INT), COUNT("count", PrestoDataType.INT) { + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + return new PrestoDataType[] { Randomly.fromOptions(PrestoDataType.getRandomWithoutNull()) }; + } + }, + // count_if(x) → bigint# + // Returns the number of TRUE input values. This function is equivalent to count(CASE WHEN x THEN 1 END). + COUNT_IF("count_if", PrestoDataType.INT) { + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + return new PrestoDataType[] { Randomly.fromOptions(PrestoDataType.getRandomWithoutNull()) }; + } + }, + // every(boolean) → boolean# + // This is an alias for bool_and(). + EVERY("every", PrestoDataType.BOOLEAN, PrestoDataType.BOOLEAN), + // geometric_mean(x) → double# + // Returns the geometric mean of all input values. + GEOMETRIC_MEAN("geometric_mean", PrestoDataType.FLOAT) { + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + return new PrestoDataType[] { + Randomly.fromOptions(PrestoDataType.INT, PrestoDataType.FLOAT, PrestoDataType.DECIMAL) }; + } + }, + // max_by(x, y) → [same as x]# + // Returns the value of x associated with the maximum value of y over all input values. + MAX_BY("max_by", null) { + @Override + public boolean isCompatibleWithReturnType(PrestoCompositeDataType returnType) { + return true; + } + + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + return new PrestoDataType[] { returnType.getPrimitiveDataType(), + Randomly.fromList(PrestoDataType.getOrderableTypes()) }; + } + + @Override + public PrestoDataType getReturnType() { + return Randomly.fromList(PrestoDataType.getOrderableTypes()); + } + + @Override + public List getArgumentsForReturnType(PrestoTypedExpressionGenerator gen, int depth, + PrestoCompositeDataType returnType, boolean orderable) { + PrestoCompositeDataType returnTypeLocal = Objects.requireNonNullElseGet(returnType, + () -> PrestoCompositeDataType.fromDataType(getReturnType())); + return super.getArgumentsForReturnType(gen, depth, returnTypeLocal, orderable); + } + + }, + + // TODO: + // + // max_by(x, y, n) → array<[same as x]># + // Returns n values of x associated with the n largest of all input values of y in descending order of y. + + // min_by(x, y) → [same as x]# + // Returns the value of x associated with the minimum value of y over all input values. + MIN_BY("min_by", null) { + @Override + public boolean isCompatibleWithReturnType(PrestoCompositeDataType returnType) { + return true; + } + + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + return new PrestoDataType[] { returnType.getPrimitiveDataType(), + Randomly.fromList(PrestoDataType.getOrderableTypes()) }; + } + + @Override + public PrestoDataType getReturnType() { + return Randomly.fromList(PrestoDataType.getOrderableTypes()); + } + + @Override + public List getArgumentsForReturnType(PrestoTypedExpressionGenerator gen, int depth, + PrestoCompositeDataType returnType, boolean orderable) { + PrestoCompositeDataType returnTypeLocal = Objects.requireNonNullElseGet(returnType, + () -> PrestoCompositeDataType.fromDataType(getReturnType())); + return super.getArgumentsForReturnType(gen, depth, returnTypeLocal, orderable); + } + + }, + // TODO: + // + // min_by(x, y, n) → array<[same as x]> + // Returns n values of x associated with the n smallest of all input values of y in ascending order of y. + + // max(x) → [same as input] + // Returns the maximum value of all input values. + MAX("max", null) { + @Override + public boolean isCompatibleWithReturnType(PrestoCompositeDataType returnType) { + boolean isCompatible = PrestoDataType.getOrderableTypes().contains(returnType.getPrimitiveDataType()); + if (returnType.getPrimitiveDataType() == PrestoDataType.ARRAY && returnType.toString().contains("JSON")) { + isCompatible = false; + } + return isCompatible; + } + + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + return new PrestoDataType[] { returnType.getPrimitiveDataType() }; + } + + @Override + public PrestoDataType getReturnType() { + return Randomly.fromList(PrestoDataType.getOrderableTypes()); + } + + @Override + public PrestoCompositeDataType getCompositeReturnType() { + PrestoDataType dataType = Randomly.fromList(PrestoDataType.getOrderableTypes()); + PrestoCompositeDataType returnType; + do { + returnType = PrestoCompositeDataType.fromDataType(dataType); + } while (!isCompatibleWithReturnType(returnType)); + return returnType; + } + + @Override + public List getArgumentsForReturnType(PrestoTypedExpressionGenerator gen, int depth, + PrestoCompositeDataType returnType, boolean orderable) { + PrestoCompositeDataType returnTypeLocal = Objects.requireNonNullElseGet(returnType, + () -> PrestoCompositeDataType.fromDataType(getReturnType())); + return super.getArgumentsForReturnType(gen, depth, returnTypeLocal, true); + } + + }, + + // TODO: + // + // max(x, n) → array<[same as x]># + // Returns n largest values of all input values of x. + + // min(x) → [same as input]# + // Returns the minimum value of all input values. + MIN("min", null) { + @Override + public boolean isCompatibleWithReturnType(PrestoCompositeDataType returnType) { + boolean orderable = PrestoDataType.getOrderableTypes().contains(returnType.getPrimitiveDataType()); + if (returnType.getPrimitiveDataType() == PrestoDataType.ARRAY && returnType.toString().contains("JSON")) { + orderable = false; + } + return orderable; + } + + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + return new PrestoDataType[] { returnType.getPrimitiveDataType() }; + } + + @Override + public PrestoDataType getReturnType() { + return Randomly.fromList(PrestoDataType.getOrderableTypes()); + } + + @Override + public List getArgumentsForReturnType(PrestoTypedExpressionGenerator gen, int depth, + PrestoCompositeDataType returnType, boolean orderable) { + PrestoCompositeDataType returnTypeLocal = Objects.requireNonNullElseGet(returnType, + () -> PrestoCompositeDataType.fromDataType(getReturnType())); + return super.getArgumentsForReturnType(gen, depth, returnTypeLocal, orderable); + } + + }, + + // TODO: + // + // min(x, n) → array<[same as x]># + // Returns n smallest values of all input values of x. + + // TODO: + // + // reduce_agg(inputValue T, initialState S, inputFunction(S, T, S), combineFunction(S, S, S)) → S# + // Reduces all input values into a single value. inputFunction will be invoked for each input value. In addition to + // taking the input value, inputFunction takes the current state, initially initialState, and returns the new state. + // combineFunction will be invoked to combine two states into a new state. The final state is returned: + // + // SELECT id, reduce_agg(value, (a, b) -> a + b, (a, b) -> a + b) + // FROM ( + // VALUES + // (1, 2), + // (1, 3), + // (1, 4), + // (2, 20), + // (2, 30), + // (2, 40) + // ) AS t(id, value) + // GROUP BY id; + // -- (1, 9) + // -- (2, 90) + // + // SELECT id, reduce_agg(value, (a, b) -> a * b, (a, b) -> a * b) + // FROM ( + // VALUES + // (1, 2), + // (1, 3), + // (1, 4), + // (2, 20), + // (2, 30), + // (2, 40) + // ) AS t(id, value) + // GROUP BY id; + // -- (1, 24) + // -- (2, 24000) + // The state type must be a boolean, integer, floating-point, or date/time/interval. + + // TODO: + // + // set_agg(x) → array<[same as input]># + // Returns an array created from the distinct input x elements. + + // TODO: + // + // set_union(array(T)) -> array(T)# + // Returns an array of all the distinct values contained in each array of the input + // + // Example: + // + // SELECT set_union(elements) + // FROM ( + // VALUES + // ARRAY[1, 3], + // ARRAY[2, 4] + // ) AS t(elements); + // Returns ARRAY[1, 3, 4] + + // sum(x) → [same as input]# + // Returns the sum of all input values. + SUM("sum", null) { + @Override + public boolean isCompatibleWithReturnType(PrestoCompositeDataType returnType) { + return List.of(PrestoDataType.INT, PrestoDataType.FLOAT, PrestoDataType.DECIMAL) + .contains(returnType.getPrimitiveDataType()); + } + + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + return new PrestoDataType[] { returnType.getPrimitiveDataType() }; + } + + @Override + public PrestoDataType getReturnType() { + return Randomly.fromOptions(PrestoDataType.INT, PrestoDataType.FLOAT, PrestoDataType.DECIMAL); + } + + @Override + public List getArgumentsForReturnType(PrestoTypedExpressionGenerator gen, int depth, + PrestoCompositeDataType returnType, boolean orderable) { + PrestoCompositeDataType returnTypeLocal = Objects.requireNonNullElseGet(returnType, + () -> PrestoCompositeDataType.fromDataType(getReturnType())); + return super.getArgumentsForReturnType(gen, depth, returnTypeLocal, orderable); + } + }, + // sum(time interval type) → time interval type# + // Returns the average interval length of all input values. + SUM_INTERVAL_YM("sum", PrestoDataType.INTERVAL_YEAR_TO_MONTH, PrestoDataType.INTERVAL_YEAR_TO_MONTH), + SUM_INTERVAL_DS("sum", PrestoDataType.INTERVAL_DAY_TO_SECOND, PrestoDataType.INTERVAL_DAY_TO_SECOND), + + // Bitwise Aggregate Functions# + + // bitwise_and_agg(x) → bigint# + // Returns the bitwise AND of all input values in 2’s complement representation. + BITWISE_AND_AGG("bitwise_and_agg", PrestoDataType.INT, PrestoDataType.INT), + + // bitwise_or_agg(x) → bigint# + // Returns the bitwise OR of all input values in 2’s complement representation. + BITWISE_OR_AGG("bitwise_or_agg", PrestoDataType.INT, PrestoDataType.INT), + + // TODO: + // + // Map Aggregate Functions + + // histogram(x)# + // Returns a map containing the count of the number of times each input value occurs. + // + // map_agg(key, value)# + // Returns a map created from the input key / value pairs. + // + // map_union(x(K, V)) -> map(K, V)# + // Returns the union of all the input maps. If a key is found in multiple input maps, that key’s value in the + // resulting map comes from an arbitrary input map. + // + // map_union_sum(x(K, V)) -> map(K, V)# + // Returns the union of all the input maps summing the values of matching keys in all the maps. All null values in + // the original maps are coalesced to 0. + // + // multimap_agg(key, value)# + // Returns a multimap created from the input key / value pairs. Each key can be associated with multiple values. + + // Approximate Aggregate Functions# + // approx_distinct(x) → bigint# + // Returns the approximate number of distinct input values. This function provides an approximation of + // count(DISTINCT x). + // Zero is returned if all input values are null. + // This function should produce a standard error of 2.3%, which is the standard deviation of the (approximately + // normal) + // error distribution over all possible sets. It does not guarantee an upper bound on the error for any specific + // input set. + APPROX_DISTINCT("approx_distinct", PrestoDataType.INT) { + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + return new PrestoDataType[] { Randomly.fromList(PrestoDataType.getOrderableTypes()) }; + } + }, + // + // approx_distinct(x, e) → bigint# + // Returns the approximate number of distinct input values. This function provides an approximation of + // count(DISTINCT x). Zero is returned if all input values are null. + // + // This function should produce a standard error of no more than e, which is the standard deviation of the + // (approximately normal) error distribution over all possible sets. It does not guarantee an upper bound on the + // error for any specific input set. The current implementation of this function requires that e be in the range of + // [0.0040625, 0.26000]. + APPROX_DISTINCT_2("approx_distinct", PrestoDataType.INT) { + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + return new PrestoDataType[] { Randomly.fromList(PrestoDataType.getOrderableTypes()), PrestoDataType.FLOAT }; + } + }, + // approx_percentile(x, percentage) → [same as x]# + // Returns the approximate percentile for all input values of x at the given percentage. + // The value of percentage must be between zero and one and must be constant for all input rows. + APPROX_PERCENTILE("approx_percentile", null) { + @Override + public boolean isCompatibleWithReturnType(PrestoCompositeDataType returnType) { + return List.of(PrestoDataType.INT, PrestoDataType.FLOAT).contains(returnType.getPrimitiveDataType()); + } + + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + return new PrestoDataType[] { Randomly.fromOptions(PrestoDataType.INT, PrestoDataType.FLOAT), + PrestoDataType.FLOAT }; + } + + @Override + public List getArgumentsForReturnType(PrestoTypedExpressionGenerator gen, int depth, + PrestoDataType[] argumentTypes2, PrestoCompositeDataType returnType2) { + List arguments = new ArrayList<>(); + arguments.add(gen.generateExpression(returnType2, depth + 1)); + arguments.add(new PrestoConstant.PrestoFloatConstant(Randomly.getPercentage())); + return arguments; + } + + @Override + public PrestoDataType getReturnType() { + return Randomly.fromOptions(PrestoDataType.INT, PrestoDataType.FLOAT); + } + + @Override + public List getArgumentsForReturnType(PrestoTypedExpressionGenerator gen, int depth, + PrestoCompositeDataType returnType, boolean orderable) { + PrestoCompositeDataType returnTypeLocal = Objects.requireNonNullElseGet(returnType, + () -> PrestoCompositeDataType.fromDataType(getReturnType())); + return super.getArgumentsForReturnType(gen, depth, returnTypeLocal, orderable); + } + }, + + // approx_percentile(x, percentage, accuracy) → [same as x]# + // As approx_percentile(x, percentage), but with a maximum rank error of accuracy. + // The value of accuracy must be between zero and one (exclusive) and must be constant for all input rows. + // Note that a lower “accuracy” is really a lower error threshold, and thus more accurate. The default accuracy is + // 0.01. + APPROX_PERCENTILE_ACCURACY("approx_percentile", null) { + @Override + public boolean isCompatibleWithReturnType(PrestoCompositeDataType returnType) { + return List.of(PrestoDataType.INT, PrestoDataType.FLOAT).contains(returnType.getPrimitiveDataType()); + } + + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + return new PrestoDataType[] { Randomly.fromOptions(PrestoDataType.INT, PrestoDataType.FLOAT), + PrestoDataType.FLOAT, PrestoDataType.FLOAT }; + } + + @Override + public List getArgumentsForReturnType(PrestoTypedExpressionGenerator gen, int depth, + PrestoDataType[] argumentTypes2, PrestoCompositeDataType returnType2) { + List arguments = new ArrayList<>(); + arguments.add(gen.generateExpression(returnType2, depth + 1)); + arguments.add(new PrestoConstant.PrestoFloatConstant(Randomly.getPercentage())); + if (Randomly.getBooleanWithRatherLowProbability()) { + arguments.add(new PrestoConstant.PrestoFloatConstant(0.01D)); + } else { + arguments.add(new PrestoConstant.PrestoFloatConstant(Randomly.getPercentage())); + } + return arguments; + } + + @Override + public PrestoDataType getReturnType() { + return Randomly.fromOptions(PrestoDataType.INT, PrestoDataType.FLOAT); + } + + @Override + public List getArgumentsForReturnType(PrestoTypedExpressionGenerator gen, int depth, + PrestoCompositeDataType returnType, boolean orderable) { + PrestoCompositeDataType returnTypeLocal = Objects.requireNonNullElseGet(returnType, + () -> PrestoCompositeDataType + .fromDataType(Randomly.fromOptions(PrestoDataType.INT, PrestoDataType.FLOAT))); + return super.getArgumentsForReturnType(gen, depth, returnTypeLocal, orderable); + } + }, + + // TODO: + // + // approx_percentile(x, percentages) → array<[same as x]># + // Returns the approximate percentile for all input values of x at each of the specified percentages. Each element + // of the percentages array must be between zero and one, and the array must be constant for all input rows. + // + // approx_percentile(x, percentages, accuracy) → array<[same as x]># + // As approx_percentile(x, percentages), but with a maximum rank error of accuracy. + + // approx_percentile(x, w, percentage) → [same as x]# + // Returns the approximate weighed percentile for all input values of x using the per-item weight w at the + // percentage p. + // The weight must be an integer value of at least one. + // It is effectively a replication count for the value x in the percentile set. + // The value of p must be between zero and one and must be constant for all input rows. + APPROX_PERCENTILE_WEIGHT("approx_percentile", null) { + @Override + public boolean isCompatibleWithReturnType(PrestoCompositeDataType returnType) { + return List.of(PrestoDataType.INT, PrestoDataType.FLOAT).contains(returnType.getPrimitiveDataType()); + } + + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + return new PrestoDataType[] { Randomly.fromOptions(PrestoDataType.INT, PrestoDataType.FLOAT), + PrestoDataType.INT, PrestoDataType.FLOAT }; + } + + @Override + public List getArgumentsForReturnType(PrestoTypedExpressionGenerator gen, int depth, + PrestoDataType[] argumentTypes2, PrestoCompositeDataType returnType2) { + List arguments = new ArrayList<>(); + arguments.add(gen.generateExpression(returnType2, depth + 1)); + arguments.add(new PrestoConstant.PrestoFloatConstant(Randomly.getPercentage())); + if (Randomly.getBooleanWithRatherLowProbability()) { + arguments.add(new PrestoConstant.PrestoIntConstant(1)); + } else { + arguments.add(new PrestoConstant.PrestoIntConstant(Randomly.smallNumber())); + } + return arguments; + } + + @Override + public PrestoDataType getReturnType() { + return Randomly.fromOptions(PrestoDataType.INT, PrestoDataType.FLOAT); + } + + @Override + public List getArgumentsForReturnType(PrestoTypedExpressionGenerator gen, int depth, + PrestoCompositeDataType returnType, boolean orderable) { + PrestoCompositeDataType returnTypeLocal = Objects.requireNonNullElseGet(returnType, + () -> PrestoCompositeDataType + .fromDataType(Randomly.fromOptions(PrestoDataType.INT, PrestoDataType.FLOAT))); + return super.getArgumentsForReturnType(gen, depth, returnTypeLocal, orderable); + } + }, + + // approx_percentile(x, w, percentage, accuracy) → [same as x]# + // As approx_percentile(x, w, percentage), but with a maximum rank error of accuracy. + APPROX_PERCENTILE_PERCENTAGE_ACCURACY("approx_percentile", null) { + @Override + public boolean isCompatibleWithReturnType(PrestoCompositeDataType returnType) { + return List.of(PrestoDataType.INT, PrestoDataType.FLOAT).contains(returnType.getPrimitiveDataType()); + } + + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + return new PrestoDataType[] { Randomly.fromOptions(PrestoDataType.INT, PrestoDataType.FLOAT), + PrestoDataType.INT, PrestoDataType.FLOAT, PrestoDataType.FLOAT }; + } + + @Override + public List getArgumentsForReturnType(PrestoTypedExpressionGenerator gen, int depth, + PrestoDataType[] argumentTypes2, PrestoCompositeDataType returnType2) { + List arguments = new ArrayList<>(); + arguments.add(gen.generateExpression(returnType2, depth + 1)); + if (Randomly.getBooleanWithRatherLowProbability()) { + arguments.add(new PrestoConstant.PrestoIntConstant(1)); + } else { + arguments.add(new PrestoConstant.PrestoIntConstant(Randomly.smallNumber())); + } + arguments.add(new PrestoConstant.PrestoFloatConstant(Randomly.getPercentage())); + if (Randomly.getBooleanWithRatherLowProbability()) { + arguments.add(new PrestoConstant.PrestoFloatConstant(0.01D)); + } else { + arguments.add(new PrestoConstant.PrestoFloatConstant(Randomly.getPercentage())); + } + return arguments; + } + + @Override + public PrestoDataType getReturnType() { + return Randomly.fromOptions(PrestoDataType.INT, PrestoDataType.FLOAT); + } + + @Override + public List getArgumentsForReturnType(PrestoTypedExpressionGenerator gen, int depth, + PrestoCompositeDataType returnType, boolean orderable) { + PrestoCompositeDataType returnTypeLocal = Objects.requireNonNullElseGet(returnType, + () -> PrestoCompositeDataType + .fromDataType(Randomly.fromOptions(PrestoDataType.INT, PrestoDataType.FLOAT))); + return super.getArgumentsForReturnType(gen, depth, returnTypeLocal, orderable); + } + }; + + // TODO: + // + // approx_percentile(x, w, percentages) → array<[same as x]># + // Returns the approximate weighed percentile for all input values of x using the per-item weight w at each of the + // given percentages specified in the array. The weight must be an integer value of at least one. It is effectively + // a replication count for the value x in the percentile set. Each element of the array must be between zero and + // one, and the array must be constant for all input rows. + // + // approx_percentile(x, w, percentages, accuracy) → array<[same as x]># + // As approx_percentile(x, w, percentages), but with a maximum rank error of accuracy. + // + // approx_set(x) → HyperLogLog + // See HyperLogLog Functions. + // + // merge(x) → HyperLogLog + // See HyperLogLog Functions. + // + // khyperloglog_agg(x) → KHyperLogLog + // See KHyperLogLog Functions. + + // TODO: + // + // merge(qdigest(T)) -> qdigest(T) + // See Quantile Digest Functions. + // + // qdigest_agg(x) → qdigest<[same as x]> + // See Quantile Digest Functions. + // + // qdigest_agg(x, w) → qdigest<[same as x]> + // See Quantile Digest Functions. + // + // qdigest_agg(x, w, accuracy) → qdigest<[same as x]> + // See Quantile Digest Functions. + // + // numeric_histogram(buckets, value, weight) → map# + // Computes an approximate histogram with up to buckets number of buckets for all values with a per-item weight of + // weight. + // The keys of the returned map are roughly the center of the bin, and the entry is the total weight of the bin. + // The algorithm is based loosely on [BenHaimTomTov2010]. + // + // buckets must be a bigint. value and weight must be numeric. + // + // numeric_histogram(buckets, value) → map# + // Computes an approximate histogram with up to buckets number of buckets for all values. This function is + // equivalent to the variant of numeric_histogram() that takes a weight, with a per-item weight of 1. In this case, + // the total weight in the returned map is the count of items in the bin. + + private final PrestoDataType returnType; + private final PrestoDataType[] argumentTypes; + private final String functionName; + + PrestoAggregateFunction(String functionName, PrestoDataType returnType) { + this.functionName = functionName; + this.returnType = returnType; + this.argumentTypes = new PrestoDataType[0]; + } + + PrestoAggregateFunction(String functionName, PrestoDataType returnType, PrestoDataType... argumentTypes) { + this.functionName = functionName; + this.returnType = returnType; + this.argumentTypes = argumentTypes.clone(); + } + + public static PrestoAggregateFunction getRandomMetamorphicOracle() { + return Randomly.fromOptions(ARBITRARY, AVG, AVG_INTERVAL_YM, AVG_INTERVAL_DS, BOOL_AND, BOOL_OR, CHECKSUM, + COUNT_ALL, COUNT_NOARGS, COUNT, COUNT_IF, EVERY, GEOMETRIC_MEAN, MAX_BY, MIN_BY, MAX, MIN, SUM, + SUM_INTERVAL_YM, SUM_INTERVAL_DS, BITWISE_AND_AGG, BITWISE_OR_AGG); + } + + public static PrestoAggregateFunction getRandom() { + return Randomly.fromOptions(values()); + } + + public static List getFunctionsCompatibleWith(PrestoCompositeDataType returnType) { + return Stream.of(values()).filter(f -> f.isCompatibleWithReturnType(returnType)).collect(Collectors.toList()); + } + + @Override + public String getFunctionName() { + return functionName; + } + + @Override + public boolean isCompatibleWithReturnType(PrestoCompositeDataType returnType) { + return this.returnType == returnType.getPrimitiveDataType(); + } + + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + return argumentTypes.clone(); + } + + @Override + public int getNumberOfArguments() { + return 1; + } + + public List getReturnTypes(PrestoSchema.PrestoDataType dataType) { + return Collections.singletonList(dataType); + } + + public PrestoDataType getReturnType() { + if (returnType == null) { + return PrestoDataType.getRandomWithoutNull(); + } + return returnType; + } + + public PrestoCompositeDataType getCompositeReturnType() { + PrestoDataType dataType = getReturnType(); + return PrestoCompositeDataType.fromDataType(dataType); + } +} diff --git a/src/sqlancer/presto/ast/PrestoAlias.java b/src/sqlancer/presto/ast/PrestoAlias.java new file mode 100644 index 000000000..b327bed11 --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoAlias.java @@ -0,0 +1,9 @@ +package sqlancer.presto.ast; + +import sqlancer.common.ast.newast.NewAliasNode; + +public class PrestoAlias extends NewAliasNode implements PrestoExpression { + public PrestoAlias(PrestoExpression expr, String alias) { + super(expr, alias); + } +} diff --git a/src/sqlancer/presto/ast/PrestoAtTimeZoneOperator.java b/src/sqlancer/presto/ast/PrestoAtTimeZoneOperator.java new file mode 100644 index 000000000..ac73a0a79 --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoAtTimeZoneOperator.java @@ -0,0 +1,20 @@ +package sqlancer.presto.ast; + +public class PrestoAtTimeZoneOperator implements PrestoExpression { + + private final PrestoExpression expr; + private final PrestoExpression timeZone; + + public PrestoAtTimeZoneOperator(PrestoExpression expr, PrestoExpression timeZone) { + this.expr = expr; + this.timeZone = timeZone; + } + + public PrestoExpression getExpr() { + return expr; + } + + public PrestoExpression getTimeZone() { + return timeZone; + } +} diff --git a/src/sqlancer/presto/ast/PrestoBetweenOperation.java b/src/sqlancer/presto/ast/PrestoBetweenOperation.java new file mode 100644 index 000000000..394a5f675 --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoBetweenOperation.java @@ -0,0 +1,10 @@ +package sqlancer.presto.ast; + +import sqlancer.common.ast.newast.NewBetweenOperatorNode; + +public class PrestoBetweenOperation extends NewBetweenOperatorNode implements PrestoExpression { + public PrestoBetweenOperation(PrestoExpression left, PrestoExpression middle, PrestoExpression right, + boolean isTrue) { + super(left, middle, right, isTrue); + } +} diff --git a/src/sqlancer/presto/ast/PrestoBinaryOperation.java b/src/sqlancer/presto/ast/PrestoBinaryOperation.java new file mode 100644 index 000000000..90f83f31a --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoBinaryOperation.java @@ -0,0 +1,10 @@ +package sqlancer.presto.ast; + +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.newast.NewBinaryOperatorNode; + +public class PrestoBinaryOperation extends NewBinaryOperatorNode implements PrestoExpression { + public PrestoBinaryOperation(PrestoExpression left, PrestoExpression right, BinaryOperatorNode.Operator op) { + super(left, right, op); + } +} diff --git a/src/sqlancer/presto/ast/PrestoCaseOperation.java b/src/sqlancer/presto/ast/PrestoCaseOperation.java new file mode 100644 index 000000000..117c54bbb --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoCaseOperation.java @@ -0,0 +1,12 @@ +package sqlancer.presto.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewCaseOperatorNode; + +public class PrestoCaseOperation extends NewCaseOperatorNode implements PrestoExpression { + public PrestoCaseOperation(PrestoExpression switchCondition, List conditions, + List expressions, PrestoExpression elseExpr) { + super(switchCondition, conditions, expressions, elseExpr); + } +} diff --git a/src/sqlancer/presto/ast/PrestoCastFunction.java b/src/sqlancer/presto/ast/PrestoCastFunction.java new file mode 100644 index 000000000..14798267d --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoCastFunction.java @@ -0,0 +1,23 @@ +package sqlancer.presto.ast; + +import sqlancer.presto.PrestoSchema; + +public class PrestoCastFunction implements PrestoExpression { + + private final PrestoExpression expr; + private final PrestoSchema.PrestoCompositeDataType type; + + public PrestoCastFunction(PrestoExpression expr, PrestoSchema.PrestoCompositeDataType type) { + this.expr = expr; + this.type = type; + } + + public PrestoExpression getExpr() { + return expr; + } + + public PrestoSchema.PrestoCompositeDataType getType() { + return type; + } + +} diff --git a/src/sqlancer/presto/ast/PrestoColumnReference.java b/src/sqlancer/presto/ast/PrestoColumnReference.java new file mode 100644 index 000000000..8bc6db0ec --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoColumnReference.java @@ -0,0 +1,13 @@ +package sqlancer.presto.ast; + +import sqlancer.common.ast.newast.ColumnReferenceNode; +import sqlancer.presto.PrestoSchema; + +public class PrestoColumnReference extends ColumnReferenceNode + implements PrestoExpression { + + public PrestoColumnReference(PrestoSchema.PrestoColumn column) { + super(column); + } + +} diff --git a/src/sqlancer/presto/ast/PrestoComparisonFunction.java b/src/sqlancer/presto/ast/PrestoComparisonFunction.java new file mode 100644 index 000000000..1125c3e35 --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoComparisonFunction.java @@ -0,0 +1,81 @@ +package sqlancer.presto.ast; + +import java.util.ArrayList; + +import sqlancer.Randomly; +import sqlancer.presto.PrestoSchema; +import sqlancer.presto.PrestoSchema.PrestoCompositeDataType; +import sqlancer.presto.PrestoSchema.PrestoDataType; + +public enum PrestoComparisonFunction implements PrestoFunction { + + // comparison + + // Returns the largest of the provided values. + // → [same as input] + GREATEST("greatest", null) { + @Override + public boolean isCompatibleWithReturnType(PrestoSchema.PrestoCompositeDataType returnType) { + return PrestoDataType.getOrderableTypes().contains(returnType.getPrimitiveDataType()); + } + + @Override + public int getNumberOfArguments() { + return -1; + } + + @Override + public PrestoSchema.PrestoDataType[] getArgumentTypes(PrestoSchema.PrestoCompositeDataType returnType) { + return new PrestoSchema.PrestoDataType[] { returnType.getPrimitiveDataType() }; + } + }, + // Returns the smallest of the provided values. + // → [same as input]# + LEAST("least", null) { + @Override + public boolean isCompatibleWithReturnType(PrestoSchema.PrestoCompositeDataType returnType) { + return PrestoDataType.getOrderableTypes().contains(returnType.getPrimitiveDataType()); + } + + @Override + public int getNumberOfArguments() { + return -1; + } + + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + ArrayList prestoDataTypes = new ArrayList<>(); + long no = Randomly.getNotCachedInteger(2, 10); + for (int i = 0; i < no; i++) { + prestoDataTypes.add(returnType.getPrimitiveDataType()); + } + return prestoDataTypes.toArray(new PrestoDataType[0]); + } + }; + + private final PrestoDataType returnType; + private final PrestoDataType[] argumentTypes; + private final String functionName; + + PrestoComparisonFunction(String functionName, PrestoDataType returnType, PrestoDataType... argumentTypes) { + this.functionName = functionName; + this.returnType = returnType; + this.argumentTypes = argumentTypes.clone(); + } + + @Override + public String getFunctionName() { + return functionName; + } + + @Override + public boolean isCompatibleWithReturnType(PrestoCompositeDataType returnType) { + return this.returnType == returnType.getPrimitiveDataType(); + } + + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + return argumentTypes.clone(); + } + +} diff --git a/src/sqlancer/presto/ast/PrestoConditionalFunction.java b/src/sqlancer/presto/ast/PrestoConditionalFunction.java new file mode 100644 index 000000000..e603a7e5d --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoConditionalFunction.java @@ -0,0 +1,95 @@ +package sqlancer.presto.ast; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.presto.PrestoSchema.PrestoCompositeDataType; +import sqlancer.presto.PrestoSchema.PrestoDataType; + +public enum PrestoConditionalFunction implements PrestoFunction { + + // Conditional functions + IF_TRUE("if", null) { + @Override + public boolean isCompatibleWithReturnType(PrestoCompositeDataType returnType) { + return true; + } + + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + return new PrestoDataType[] { PrestoDataType.BOOLEAN, returnType.getPrimitiveDataType() }; + } + }, + + IF_TRUE_FALSE("if", null) { + @Override + public boolean isCompatibleWithReturnType(PrestoCompositeDataType returnType) { + return true; + } + + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + return new PrestoDataType[] { PrestoDataType.BOOLEAN, returnType.getPrimitiveDataType(), + returnType.getPrimitiveDataType() }; + } + }, + + NULLIF("nullif", null) { + @Override + public boolean isCompatibleWithReturnType(PrestoCompositeDataType returnType) { + return true; + } + + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + return new PrestoDataType[] { returnType.getPrimitiveDataType(), returnType.getPrimitiveDataType() }; + } + }, + + COALESCE("coalesce", null) { + @Override + public boolean isCompatibleWithReturnType(PrestoCompositeDataType returnType) { + return true; + } + + @Override + public int getNumberOfArguments() { + return -1; + } + + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + List prestoDataTypes = new ArrayList<>(); + long no = Randomly.getNotCachedInteger(2, 10); + for (int i = 0; i < no; i++) { + prestoDataTypes.add(returnType.getPrimitiveDataType()); + } + return prestoDataTypes.toArray(new PrestoDataType[0]); + } + }; + + private final PrestoDataType returnType; + private final String functionName; + + PrestoConditionalFunction(String functionName, PrestoDataType returnType) { + this.functionName = functionName; + this.returnType = returnType; + } + + @Override + public String getFunctionName() { + return functionName; + } + + @Override + public boolean isCompatibleWithReturnType(PrestoCompositeDataType returnType) { + return this.returnType == returnType.getPrimitiveDataType(); + } + + @Override + public int getNumberOfArguments() { + return getArgumentTypes(PrestoCompositeDataType.fromDataType(returnType)).length; + } + +} diff --git a/src/sqlancer/presto/ast/PrestoConstant.java b/src/sqlancer/presto/ast/PrestoConstant.java new file mode 100644 index 000000000..a469b983c --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoConstant.java @@ -0,0 +1,808 @@ +package sqlancer.presto.ast; + +import java.math.BigDecimal; +import java.sql.Timestamp; +import java.text.DecimalFormat; +import java.text.SimpleDateFormat; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.presto.PrestoConstantUtils; +import sqlancer.presto.PrestoSchema; + +public abstract class PrestoConstant implements PrestoExpression { + + private static final String[] TIME_ZONES = { "Africa/Abidjan", "Africa/Accra", "Africa/Addis_Ababa", + "Africa/Algiers", "Africa/Asmara", "Africa/Asmera", "Africa/Bamako", "Africa/Bangui", "Africa/Banjul", + "Africa/Bissau", "Africa/Blantyre", "Africa/Brazzaville", "Africa/Bujumbura", "Africa/Cairo", + "Africa/Casablanca", "Africa/Ceuta", "Africa/Conakry", "Africa/Dakar", "Africa/Dar_es_Salaam", + "Africa/Djibouti", "Africa/Douala", "Africa/El_Aaiun", "Africa/Freetown", "Africa/Gaborone", + "Africa/Harare", "Africa/Johannesburg", "Africa/Juba", "Africa/Kampala", "Africa/Khartoum", "Africa/Kigali", + "Africa/Kinshasa", "Africa/Lagos", "Africa/Libreville", "Africa/Lome", "Africa/Luanda", "Africa/Lubumbashi", + "Africa/Lusaka", "Africa/Malabo", "Africa/Maputo", "Africa/Maseru", "Africa/Mbabane", "Africa/Mogadishu", + "Africa/Monrovia", "Africa/Nairobi", "Africa/Ndjamena", "Africa/Niamey", "Africa/Nouakchott", + "Africa/Ouagadougou", "Africa/Porto-Novo", "Africa/Sao_Tome", "Africa/Timbuktu", "Africa/Tripoli", + "Africa/Tunis", "Africa/Windhoek", "America/Adak", "America/Anchorage", "America/Anguilla", + "America/Antigua", "America/Araguaina", "America/Argentina/Buenos_Aires", "America/Argentina/Catamarca", + "America/Argentina/ComodRivadavia", "America/Argentina/Cordoba", "America/Argentina/Jujuy", + "America/Argentina/La_Rioja", "America/Argentina/Mendoza", "America/Argentina/Rio_Gallegos", + "America/Argentina/Salta", "America/Argentina/San_Juan", "America/Argentina/San_Luis", + "America/Argentina/Tucuman", "America/Argentina/Ushuaia", "America/Aruba", "America/Asuncion", + "America/Atikokan", "America/Atka", "America/Bahia", "America/Barbados", "America/Belem", "America/Belize", + "America/Blanc-Sablon", "America/Boa_Vista", "America/Bogota", "America/Boise", "America/Buenos_Aires", + "America/Cambridge_Bay", "America/Campo_Grande", "America/Cancun", "America/Caracas", "America/Catamarca", + "America/Cayenne", "America/Cayman", "America/Chicago", "America/Chihuahua", "America/Coral_Harbour", + "America/Cordoba", "America/Costa_Rica", "America/Creston", "America/Cuiaba", "America/Curacao", + "America/Danmarkshavn", "America/Dawson", "America/Dawson_Creek", "America/Denver", "America/Detroit", + "America/Dominica", "America/Edmonton", "America/Eirunepe", "America/El_Salvador", "America/Ensenada", + "America/Fort_Nelson", "America/Fort_Wayne", "America/Fortaleza", "America/Glace_Bay", "America/Godthab", + "America/Goose_Bay", "America/Grand_Turk", "America/Grenada", "America/Guadeloupe", "America/Guatemala", + "America/Guayaquil", "America/Guyana", "America/Halifax", "America/Havana", "America/Indiana/Indianapolis", + "America/Indiana/Knox", "America/Indiana/Marengo", "America/Indiana/Petersburg", + "America/Indiana/Tell_City", "America/Indiana/Vevay", "America/Indiana/Vincennes", + "America/Indiana/Winamac", "America/Indianapolis", "America/Inuvik", "America/Iqaluit", "America/Jamaica", + "America/Jujuy", "America/Juneau", "America/Kentucky/Louisville", "America/Kentucky/Monticello", + "America/Knox_IN", "America/Kralendijk", "America/La_Paz", "America/Lima", "America/Los_Angeles", + "America/Louisville", "America/Lower_Princes", "America/Maceio", "America/Managua", "America/Manaus", + "America/Marigot", "America/Martinique", "America/Matamoros", "America/Mendoza", "America/Menominee", + "America/Merida", "America/Metlakatla", "America/Mexico_City", "America/Miquelon", "America/Moncton", + "America/Monterrey", "America/Montevideo", "America/Montreal", "America/Montserrat", "America/Nassau", + "America/New_York", "America/Nipigon", "America/Nome", "America/Noronha", "America/North_Dakota/Beulah", + "America/North_Dakota/Center", "America/North_Dakota/New_Salem", "America/Nuuk", "America/Ojinaga", + "America/Panama", "America/Pangnirtung", "America/Paramaribo", "America/Phoenix", "America/Port-au-Prince", + "America/Port_of_Spain", "America/Porto_Acre", "America/Porto_Velho", "America/Puerto_Rico", + "America/Punta_Arenas", "America/Rainy_River", "America/Rankin_Inlet", "America/Recife", "America/Regina", + "America/Resolute", "America/Rio_Branco", "America/Rosario", "America/Santa_Isabel", "America/Santarem", + "America/Santiago", "America/Santo_Domingo", "America/Sao_Paulo", "America/Scoresbysund", + "America/Shiprock", "America/Sitka", "America/St_Barthelemy", "America/St_Johns", "America/St_Kitts", + "America/St_Lucia", "America/St_Thomas", "America/St_Vincent", "America/Swift_Current", + "America/Tegucigalpa", "America/Thule", "America/Thunder_Bay", "America/Tijuana", "America/Toronto", + "America/Tortola", "America/Vancouver", "America/Virgin", "America/Whitehorse", "America/Winnipeg", + "America/Yakutat", "America/Yellowknife", "Antarctica/Casey", "Antarctica/Davis", + "Antarctica/DumontDUrville", "Antarctica/Macquarie", "Antarctica/Mawson", "Antarctica/McMurdo", + "Antarctica/Palmer", "Antarctica/Rothera", "Antarctica/South_Pole", "Antarctica/Syowa", "Antarctica/Troll", + "Antarctica/Vostok", "Arctic/Longyearbyen", "Asia/Aden", "Asia/Almaty", "Asia/Amman", "Asia/Anadyr", + "Asia/Aqtau", "Asia/Aqtobe", "Asia/Ashgabat", "Asia/Ashkhabad", "Asia/Atyrau", "Asia/Baghdad", + "Asia/Bahrain", "Asia/Baku", "Asia/Bangkok", "Asia/Barnaul", "Asia/Beirut", "Asia/Bishkek", "Asia/Brunei", + "Asia/Calcutta", "Asia/Chita", "Asia/Choibalsan", "Asia/Chongqing", "Asia/Chungking", "Asia/Colombo", + "Asia/Dacca", "Asia/Dhaka", "Asia/Dili", "Asia/Dubai", "Asia/Dushanbe", "Asia/Famagusta", "Asia/Gaza", + "Asia/Harbin", "Asia/Hebron", "Asia/Ho_Chi_Minh", "Asia/Hong_Kong", "Asia/Hovd", "Asia/Irkutsk", + "Asia/Istanbul", "Asia/Jakarta", "Asia/Jayapura", "Asia/Jerusalem", "Asia/Kabul", "Asia/Kamchatka", + "Asia/Karachi", "Asia/Kashgar", "Asia/Kathmandu", "Asia/Katmandu", "Asia/Khandyga", "Asia/Kolkata", + "Asia/Krasnoyarsk", "Asia/Kuala_Lumpur", "Asia/Kuching", "Asia/Kuwait", "Asia/Macao", "Asia/Macau", + "Asia/Magadan", "Asia/Makassar", "Asia/Manila", "Asia/Muscat", "Asia/Nicosia", "Asia/Novokuznetsk", + "Asia/Novosibirsk", "Asia/Omsk", "Asia/Oral", "Asia/Phnom_Penh", "Asia/Pontianak", "Asia/Pyongyang", + "Asia/Qatar", "Asia/Qostanay", "Asia/Qyzylorda", "Asia/Rangoon", "Asia/Riyadh", "Asia/Saigon", + "Asia/Sakhalin", "Asia/Samarkand", "Asia/Seoul", "Asia/Shanghai", "Asia/Singapore", "Asia/Srednekolymsk", + "Asia/Taipei", "Asia/Tashkent", "Asia/Tbilisi", "Asia/Tehran", "Asia/Tel_Aviv", "Asia/Thimbu", + "Asia/Thimphu", "Asia/Tokyo", "Asia/Tomsk", "Asia/Ujung_Pandang", "Asia/Ulaanbaatar", "Asia/Ulan_Bator", + "Asia/Urumqi", "Asia/Ust-Nera", "Asia/Vientiane", "Asia/Vladivostok", "Asia/Yakutsk", "Asia/Yangon", + "Asia/Yekaterinburg", "Asia/Yerevan", "Atlantic/Azores", "Atlantic/Bermuda", "Atlantic/Canary", + "Atlantic/Cape_Verde", "Atlantic/Faeroe", "Atlantic/Faroe", "Atlantic/Jan_Mayen", "Atlantic/Madeira", + "Atlantic/Reykjavik", "Atlantic/South_Georgia", "Atlantic/St_Helena", "Atlantic/Stanley", "Australia/ACT", + "Australia/Adelaide", "Australia/Brisbane", "Australia/Broken_Hill", "Australia/Canberra", + "Australia/Currie", "Australia/Darwin", "Australia/Eucla", "Australia/Hobart", "Australia/LHI", + "Australia/Lindeman", "Australia/Lord_Howe", "Australia/Melbourne", "Australia/North", "Australia/Perth", + "Australia/Queensland", "Australia/South", "Australia/Sydney", "Australia/Tasmania", "Australia/Victoria", + "Australia/West", "Australia/Yancowinna", "Brazil/Acre", "Brazil/DeNoronha", "Brazil/East", "Brazil/West", + "CET", "CST6CDT", "Canada/Atlantic", "Canada/Central", "Canada/Eastern", "Canada/Mountain", + "Canada/Newfoundland", "Canada/Pacific", "Canada/Saskatchewan", "Canada/Yukon", "Chile/Continental", + "Chile/EasterIsland", "Cuba", "EET", "EST5EDT", "Egypt", "Eire", "Etc/GMT", "Etc/GMT+0", "Etc/GMT+1", + "Etc/GMT+10", "Etc/GMT+11", "Etc/GMT+12", "Etc/GMT+2", "Etc/GMT+3", "Etc/GMT+4", "Etc/GMT+5", "Etc/GMT+6", + "Etc/GMT+7", "Etc/GMT+8", "Etc/GMT+9", "Etc/GMT-0", "Etc/GMT-1", "Etc/GMT-10", "Etc/GMT-11", "Etc/GMT-12", + "Etc/GMT-13", "Etc/GMT-14", "Etc/GMT-2", "Etc/GMT-3", "Etc/GMT-4", "Etc/GMT-5", "Etc/GMT-6", "Etc/GMT-7", + "Etc/GMT-8", "Etc/GMT-9", "Etc/GMT0", "Etc/Greenwich", "Etc/UCT", "Etc/UTC", "Etc/Universal", "Etc/Zulu", + "Europe/Amsterdam", "Europe/Andorra", "Europe/Astrakhan", "Europe/Athens", "Europe/Belfast", + "Europe/Belgrade", "Europe/Berlin", "Europe/Bratislava", "Europe/Brussels", "Europe/Bucharest", + "Europe/Budapest", "Europe/Busingen", "Europe/Chisinau", "Europe/Copenhagen", "Europe/Dublin", + "Europe/Gibraltar", "Europe/Guernsey", "Europe/Helsinki", "Europe/Isle_of_Man", "Europe/Istanbul", + "Europe/Jersey", "Europe/Kaliningrad", "Europe/Kiev", "Europe/Kirov", "Europe/Lisbon", "Europe/Ljubljana", + "Europe/London", "Europe/Luxembourg", "Europe/Madrid", "Europe/Malta", "Europe/Mariehamn", "Europe/Minsk", + "Europe/Monaco", "Europe/Moscow", "Europe/Nicosia", "Europe/Oslo", "Europe/Paris", "Europe/Podgorica", + "Europe/Prague", "Europe/Riga", "Europe/Rome", "Europe/Samara", "Europe/San_Marino", "Europe/Sarajevo", + "Europe/Saratov", "Europe/Simferopol", "Europe/Skopje", "Europe/Sofia", "Europe/Stockholm", "Europe/Tirane", + "Europe/Tiraspol", "Europe/Ulyanovsk", "Europe/Uzhgorod", "Europe/Vaduz", "Europe/Vatican", "Europe/Vienna", + "Europe/Vilnius", "Europe/Volgograd", "Europe/Warsaw", "Europe/Zagreb", "Europe/Zaporozhye", + "Europe/Zurich", "GB", "GB-Eire", "GMT", "GMT0", "Greenwich", "Hongkong", "Iceland", "Indian/Antananarivo", + "Indian/Chagos", "Indian/Christmas", "Indian/Cocos", "Indian/Comoro", "Indian/Kerguelen", "Indian/Mahe", + "Indian/Maldives", "Indian/Mauritius", "Indian/Mayotte", "Indian/Reunion", "Iran", "Israel", "Jamaica", + "Japan", "Kwajalein", "Libya", "MET", "MST7MDT", "Mexico/General", "NZ", "NZ-CHAT", "Navajo", "PRC", + "PST8PDT", "Pacific/Apia", "Pacific/Auckland", "Pacific/Bougainville", "Pacific/Chatham", "Pacific/Chuuk", + "Pacific/Easter", "Pacific/Efate", "Pacific/Enderbury", "Pacific/Fakaofo", "Pacific/Fiji", + "Pacific/Funafuti", "Pacific/Galapagos", "Pacific/Gambier", "Pacific/Guadalcanal", "Pacific/Guam", + "Pacific/Honolulu", "Pacific/Johnston", "Pacific/Kiritimati", "Pacific/Kosrae", "Pacific/Kwajalein", + "Pacific/Majuro", "Pacific/Marquesas", "Pacific/Midway", "Pacific/Nauru", "Pacific/Niue", "Pacific/Norfolk", + "Pacific/Noumea", "Pacific/Pago_Pago", "Pacific/Palau", "Pacific/Pitcairn", "Pacific/Pohnpei", + "Pacific/Ponape", "Pacific/Port_Moresby", "Pacific/Rarotonga", "Pacific/Saipan", "Pacific/Samoa", + "Pacific/Tahiti", "Pacific/Tarawa", "Pacific/Tongatapu", "Pacific/Truk", "Pacific/Wake", "Pacific/Wallis", + "Pacific/Yap", "Poland", "Portugal", "ROK", "Singapore", "Turkey", "UCT", "US/Alaska", "US/Aleutian", + "US/Arizona", "US/Central", "US/East-Indiana", "US/Eastern", "US/Hawaii", "US/Indiana-Starke", + "US/Michigan", "US/Mountain", "US/Pacific", "US/Samoa", "UTC", "Universal", "W-SU", "WET", "Zulu" }; + private static final String FALSE = "false"; + private static final String TRUE = "true"; + + private PrestoConstant() { + } + + public static PrestoExpression createStringConstant(String text) { + return new PrestoTextConstant(text); + } + + public static PrestoExpression createStringConstant(String text, int size) { + return new PrestoTextConstant(text, size); + } + + public static PrestoExpression createJsonConstant() { + return new PrestoJsonConstant(); + } + + public static PrestoExpression createFloatConstant(PrestoSchema.PrestoCompositeDataType type, double val) { + assert type.getSize() == 4; + float floatValue = (float) val; + return new PrestoFloatConstant(floatValue); + } + + public static PrestoExpression createFloatConstant(double val) { + return new PrestoFloatConstant(val); + } + + public static PrestoExpression createDecimalConstant(double val) { + return new PrestoDecimalConstant(val); + } + + public static PrestoExpression createDecimalConstant(PrestoSchema.PrestoCompositeDataType type, double val) { + int scale = type.getScale(); + int precision = type.getSize(); + BigDecimal finalBD = PrestoConstantUtils.getDecimal(val, scale, precision); + return new PrestoDecimalConstant(finalBD.doubleValue()); + } + + public static PrestoExpression createIntConstant(long val) { + return new PrestoIntConstant(val); + } + + public static PrestoExpression createIntConstant(PrestoSchema.PrestoCompositeDataType type, long val, + boolean castInteger) { + PrestoIntConstant intConstant; + assert List.of(1, 2, 4, 8).contains(type.getSize()); + switch (type.getSize()) { + case 1: + intConstant = new PrestoIntConstant((byte) val); + break; + case 2: + intConstant = new PrestoIntConstant((short) val); + break; + case 4: + intConstant = new PrestoIntConstant((int) val); + break; + default: + intConstant = new PrestoIntConstant(val); + } + if (castInteger) { + return new PrestoCastFunction(intConstant, type); + } else { + return intConstant; + } + } + + public static PrestoExpression createNullConstant() { + return new PrestoNullConstant(); + } + + public static PrestoExpression createBooleanConstant(boolean val) { + return new PrestoBooleanConstant(val); + } + + public static PrestoExpression createDateConstant(long integer) { + return new PrestoDateConstant(integer); + } + + public static PrestoExpression createTimeConstant(long integer) { + return new PrestoTimeConstant(integer); + } + + public static PrestoExpression createTimeWithTimeZoneConstant(long integer) { + return new PrestoTimeWithTimeZoneConstant(integer); + } + + public static PrestoExpression createTimestampWithTimeZoneConstant(long integer) { + return new PrestoTimestampWithTimezoneConstant(integer); + } + + public static PrestoExpression createIntervalDayToSecond(long integer) { + return new PrestoIntervalDayToSecondConstant(); + } + + public static PrestoExpression createIntervalYearToMonth(long integer) { + return new PrestoIntervalYearToMonthConstant(); + } + + public static PrestoExpression createTimestampConstant(long integer) { + return new PrestoTimestampConstant(integer); + } + + public static PrestoExpression createVarbinaryConstant(String string) { + return new PrestoVarbinaryConstant(string); + } + + public static PrestoExpression createTimezoneConstant() { + String string = Randomly.fromOptions(TIME_ZONES); + return new PrestoTextConstant(string); + } + + public static PrestoExpression createArrayConstant(PrestoSchema.PrestoCompositeDataType type) { + PrestoSchema.PrestoCompositeDataType elementType = type.getElementType(); + long size = Randomly.getNotCachedInteger(0, 10); + + List elements = new ArrayList<>(); + for (int i = 0; i <= size; i++) { + if (elementType.getPrimitiveDataType() == PrestoSchema.PrestoDataType.ARRAY) { + elements.add(createArrayConstant(elementType)); + } else { + elements.add(generateConstant(elementType, false)); + } + } + return new PrestoArrayConstant(elements); + } + + public static PrestoExpression createMapConstant(PrestoSchema.PrestoCompositeDataType type) { + PrestoSchema.PrestoCompositeDataType elementType = type.getElementType(); + long size = Randomly.getNotCachedInteger(0, 10); + + List elements = new ArrayList<>(); + for (int i = 0; i <= size; i++) { + if (elementType.getPrimitiveDataType() == PrestoSchema.PrestoDataType.ARRAY) { + elements.add(createArrayConstant(elementType)); + } else { + elements.add(generateConstant(elementType, false)); + } + } + return new PrestoArrayConstant(elements); + } + + public static PrestoExpression generateConstant(PrestoSchema.PrestoCompositeDataType type, boolean castInteger) { + Randomly randomly = new Randomly(); + switch (type.getPrimitiveDataType()) { + case ARRAY: + return PrestoConstant.createArrayConstant(type); + case NULL: + return PrestoConstant.createNullConstant(); + case CHAR: + return PrestoConstant.PrestoTextConstant.createStringConstant(randomly.getAlphabeticChar(), type.getSize()); + case VARCHAR: + return PrestoConstant.PrestoTextConstant.createStringConstant(randomly.getString(), type.getSize()); + case VARBINARY: + return PrestoConstant.createVarbinaryConstant(randomly.getString()); + case JSON: + return PrestoConstant.PrestoJsonConstant.createJsonConstant(); + case TIME: + return PrestoConstant.createTimeConstant(randomly.getLong(0, System.currentTimeMillis())); + case TIME_WITH_TIME_ZONE: + return PrestoConstant.createTimeWithTimeZoneConstant(randomly.getLong(0, System.currentTimeMillis())); + case TIMESTAMP: + return PrestoConstant.createTimestampConstant(randomly.getLong(0, System.currentTimeMillis())); + case TIMESTAMP_WITH_TIME_ZONE: + return PrestoConstant.createTimestampWithTimeZoneConstant(randomly.getLong(0, System.currentTimeMillis())); + case INTERVAL_YEAR_TO_MONTH: + return PrestoConstant.createIntervalYearToMonth(randomly.getLong(0, System.currentTimeMillis())); + case INTERVAL_DAY_TO_SECOND: + return PrestoConstant.createIntervalDayToSecond(randomly.getLong(0, System.currentTimeMillis())); + case INT: + return PrestoConstant.PrestoIntConstant.createIntConstant(type, Randomly.getNonCachedInteger(), + castInteger); + case FLOAT: + return PrestoConstant.PrestoFloatConstant.createFloatConstant(randomly.getDouble()); + case BOOLEAN: + return PrestoConstant.PrestoBooleanConstant.createBooleanConstant(Randomly.getBoolean()); + case DATE: + return PrestoConstant.createDateConstant(randomly.getLong(0, System.currentTimeMillis())); + case DECIMAL: + return PrestoConstant.createDecimalConstant(type, randomly.getLong(0, System.currentTimeMillis())); + default: + throw new AssertionError("Unknown type: " + type); + } + } + + public boolean isNull() { + return false; + } + + public boolean isInt() { + return false; + } + + public boolean isBoolean() { + return false; + } + + public boolean isArray() { + return false; + } + + public boolean isString() { + return false; + } + + public boolean isFloat() { + return false; + } + + public boolean asBoolean() { + throw new UnsupportedOperationException(this.toString()); + } + + public long asInt() { + throw new UnsupportedOperationException(this.toString()); + } + + public String asString() { + throw new UnsupportedOperationException(this.toString()); + } + + public double asFloat() { + throw new UnsupportedOperationException(this.toString()); + } + + public static class PrestoNullConstant extends PrestoConstant { + + @Override + public String toString() { + return "NULL"; + } + + @Override + public boolean isNull() { + return true; + } + + } + + public static class PrestoIntConstant extends PrestoConstant { + + private final long value; + + public PrestoIntConstant(long value) { + this.value = value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + + public long getValue() { + return value; + } + + @Override + public boolean isInt() { + return true; + } + + } + + public static class PrestoFloatConstant extends PrestoConstant { + + private final double value; + + public PrestoFloatConstant(double value) { + this.value = value; + } + + public double getValue() { + return value; + } + + @Override + public String toString() { + if (value == Double.POSITIVE_INFINITY) { + return "infinity()"; + } else if (value == Double.NEGATIVE_INFINITY) { + return "-infinity()"; + } + return String.valueOf(value); + } + + @Override + public boolean isFloat() { + return true; + } + + @Override + public double asFloat() { + return value; + } + + } + + public static class PrestoDecimalConstant extends PrestoConstant { + + private static final DecimalFormat DECIMAL_FORMAT = new DecimalFormat("###0.0000"); + + private final double value; + + public PrestoDecimalConstant(double value) { + this.value = value; + } + + public double getValue() { + return value; + } + + @Override + public String toString() { + if (value == Double.POSITIVE_INFINITY) { + return "'+Inf'"; + } else if (value == Double.NEGATIVE_INFINITY) { + return "'-Inf'"; + } + return DECIMAL_FORMAT.format(value); + } + + @Override + public double asFloat() { + return value; + } + + } + + public static class PrestoTextConstant extends PrestoConstant { + + private final String value; + + public PrestoTextConstant(String value) { + this.value = value; + } + + public PrestoTextConstant(String value, int size) { + this.value = value.substring(0, Math.min(value.length(), size)); + } + + public String getValue() { + return value; + } + + @Override + public String toString() { + return "'" + value.replace("'", "''") + "'"; + } + + } + + public static class PrestoVarbinaryConstant extends PrestoConstant { + + private final String value; + + public PrestoVarbinaryConstant(String value) { + this.value = value.replace("'", ""); + } + + public String getValue() { + return value; + } + + @Override + public String toString() { + return String.format("CAST ('%s' AS VARBINARY)", value); + } + + } + + public static class PrestoJsonConstant extends PrestoConstant { + + private final String value; + + public PrestoJsonConstant() { + Randomly rand = new Randomly(); + JsonValueType jvt = Randomly.fromOptions(JsonValueType.values()); + String val; + switch (jvt) { + case NULL: + val = "null"; + value = "{\"val\":" + val + "}"; + break; + case FALSE: + val = FALSE; + value = "{\"val\":" + val + "}"; + break; + case TRUE: + val = TRUE; + value = "{\"val\":" + val + "}"; + break; + case STRING: + String randString = rand.getString(); + String string = randString.substring(0, Math.min(randString.length(), 250)); + string = string.replace("'", ""); + // https://www.rfc-editor.org/rfc/rfc8259#page-8 + string = PrestoConstantUtils.removeAllControlChars(string); + string = string.replace("\\", "\\\\"); + + value = "{\"val\": \"" + string + "\"}"; + break; + case NUMBER: + if (Randomly.getBoolean()) { + int no = (int) rand.getInteger(); + val = String.valueOf(no); + } else { + double no = rand.getDouble(); + val = String.valueOf(no); + } + value = "{\"val\": " + val + "}"; + break; + case ARRAY: + value = "{\"employees\":[\"John\", \"Anna\", \"Peter\"]}"; + break; + case OBJECT: + value = "{\"employee\":{\"name\":\"John\", \"age\":30, \"city\":\"New York\"}}"; + break; + default: + value = "{}"; + } + } + + public String getValue() { + return value; + } + + @Override + public String toString() { + return "JSON '" + value + "'"; + } + + private enum JsonValueType { + OBJECT, ARRAY, NUMBER, STRING, TRUE, FALSE, NULL + } + + } + + public static class PrestoDateConstant extends PrestoConstant { + + private final String textRepresentation; + + public PrestoDateConstant(long val) { + Timestamp timestamp = new Timestamp(val); + SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd"); + textRepresentation = dateFormat.format(timestamp); + } + + public String getValue() { + return textRepresentation; + } + + @Override + public String toString() { + return String.format("DATE '%s'", textRepresentation); + } + + } + + public static class PrestoTimeConstant extends PrestoConstant { + + public final String textRepresentation; + + public PrestoTimeConstant(long val) { + Timestamp timestamp = new Timestamp(val); + SimpleDateFormat dateFormat = new SimpleDateFormat("HH:mm:ss.SSS"); + textRepresentation = dateFormat.format(timestamp); + } + + public String getValue() { + return textRepresentation; + } + + @Override + public String toString() { + return String.format("TIME '%s'", textRepresentation); + } + + } + + public static class PrestoTimeWithTimeZoneConstant extends PrestoConstant { + + private final String textRepresentation; + private final String timeZone; + + public PrestoTimeWithTimeZoneConstant(long val) { + Timestamp timestamp = new Timestamp(val); + SimpleDateFormat dateFormat = new SimpleDateFormat("HH:mm:ss.SSS"); + textRepresentation = dateFormat.format(timestamp); + this.timeZone = Randomly.fromOptions(TIME_ZONES); + } + + public String getValue() { + return textRepresentation; + } + + @Override + public String toString() { + return String.format("TIME '%s %s'", textRepresentation, timeZone); + } + + } + + public static class PrestoTimestampConstant extends PrestoConstant { + + private final String textRepresentation; + + public PrestoTimestampConstant(long val) { + Timestamp timestamp = new Timestamp(val); + SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); + this.textRepresentation = dateFormat.format(timestamp); + } + + public String getValue() { + return textRepresentation; + } + + @Override + public String toString() { + return String.format("TIMESTAMP '%s'", textRepresentation); + } + + } + + public static class PrestoTimestampWithTimezoneConstant extends PrestoConstant { + + private final String textRepresentation; + private final String timeZone; + + public PrestoTimestampWithTimezoneConstant(long val) { + Timestamp timestamp = new Timestamp(val); + SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); + this.textRepresentation = dateFormat.format(timestamp); + this.timeZone = Randomly.fromOptions(TIME_ZONES); + } + + public String getValue() { + return textRepresentation; + } + + @Override + public String toString() { + return String.format("TIMESTAMP '%s %s'", textRepresentation, timeZone); + } + + } + + public static class PrestoIntervalDayToSecondConstant extends PrestoConstant { + + private final String textRepresentation; + private final Interval fromInterval; + + public PrestoIntervalDayToSecondConstant() { + this.fromInterval = Randomly.fromOptions(Interval.values()); + SimpleDateFormat dateFormat = new SimpleDateFormat("dd HH:mm:ss"); + switch (fromInterval) { + case DAY: + dateFormat = new SimpleDateFormat("dd"); + break; + case HOUR: + dateFormat = new SimpleDateFormat("HH"); + break; + case MINUTE: + dateFormat = new SimpleDateFormat("mm"); + break; + case SECOND: + dateFormat = new SimpleDateFormat("ss"); + break; + default: + break; + } + + Randomly rand = new Randomly(); + + Timestamp timestamp = new Timestamp(rand.getLong(0, System.currentTimeMillis())); + this.textRepresentation = dateFormat.format(timestamp); + } + + public String getValue() { + return textRepresentation; + } + + @Override + public String toString() { + // if (toInterval == null) { + return String.format("INTERVAL '%s' %s", textRepresentation, fromInterval.name()); + // } else { + // return String.format("INTERVAL '%s' %s TO %s", textRepresentation, fromInterval, toInterval); + // } + } + + private enum Interval { + DAY, HOUR, MINUTE, SECOND + } + + } + + public static class PrestoIntervalYearToMonthConstant extends PrestoConstant { + + public String textRepresentation; + private final Interval fromInterval; + + public PrestoIntervalYearToMonthConstant() { + fromInterval = Randomly.fromOptions(Interval.values()); + SimpleDateFormat dateFormat; + switch (fromInterval) { + case YEAR: + dateFormat = new SimpleDateFormat("yyyy"); + break; + case MONTH: + dateFormat = new SimpleDateFormat("MM"); + break; + default: + dateFormat = new SimpleDateFormat("yyyy-MM"); + } + + Randomly rand = new Randomly(); + + Timestamp timestamp = new Timestamp(rand.getLong(0, System.currentTimeMillis())); + textRepresentation = dateFormat.format(timestamp); + } + + public String getValue() { + return textRepresentation; + } + + @Override + public String toString() { + return String.format("INTERVAL '%s' %s", textRepresentation, fromInterval.name()); + } + + private enum Interval { + YEAR, MONTH + } + + } + + public static class PrestoBooleanConstant extends PrestoConstant { + + private final boolean value; + + public PrestoBooleanConstant(boolean value) { + this.value = value; + } + + public boolean getValue() { + return value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + + @Override + public boolean asBoolean() { + return value; + } + + @Override + public boolean isBoolean() { + return true; + } + + } + + public static class PrestoArrayConstant extends PrestoConstant { + + private final List elements; + + public PrestoArrayConstant(List elements) { + this.elements = new ArrayList<>(elements); + } + + @Override + public boolean isArray() { + return true; + } + + @Override + public String toString() { + return "ARRAY[" + elements.stream().map(Object::toString).collect(Collectors.joining(", ")) + "]"; + } + + } + +} diff --git a/src/sqlancer/presto/ast/PrestoDateFunction.java b/src/sqlancer/presto/ast/PrestoDateFunction.java new file mode 100644 index 000000000..a8e00a3c0 --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoDateFunction.java @@ -0,0 +1,523 @@ +package sqlancer.presto.ast; + +import sqlancer.presto.PrestoSchema.PrestoCompositeDataType; +import sqlancer.presto.PrestoSchema.PrestoDataType; + +public enum PrestoDateFunction implements PrestoFunction { + + // Date and Time Functions# + // Returns the current date as of the start of the query. + CURRENT_DATE("current_date", PrestoDataType.DATE), + + // Returns the current time as of the start of the query. + CURRENT_TIME("current_time", PrestoDataType.TIME_WITH_TIME_ZONE), + + // Returns the current timestamp as of the start of the query. + CURRENT_TIMESTAMP("current_timestamp", PrestoDataType.TIMESTAMP_WITH_TIME_ZONE), + + // Returns the current time zone in the format defined by IANA (e.g., America/Los_Angeles) or as fixed offset from + // UTC (e.g., +08:35) + CURRENT_TIMEZONE("current_timezone", PrestoDataType.VARCHAR), + + // This is an alias for CAST(x AS date). + DATE("date", PrestoDataType.DATE, PrestoDataType.DATE, PrestoDataType.INT, PrestoDataType.VARCHAR), + + // Returns the last day of the month. + LAST_DAY_OF_MONTH("last_day_of_month", PrestoDataType.DATE, PrestoDataType.DATE), + + // Parses the ISO 8601 formatted string into a timestamp with time zone. + FROM_ISO8601_TIMESTAMP("from_iso8601_timestamp", PrestoDataType.TIMESTAMP_WITH_TIME_ZONE, PrestoDataType.VARCHAR), + + // Parses the ISO 8601 formatted string into a date. + FROM_ISO8601_DATE("from_iso8601_date", PrestoDataType.DATE, PrestoDataType.VARCHAR), + + // Returns the UNIX timestamp unixtime as a timestamp. + FROM_UNIXTIME("from_unixtime", PrestoDataType.TIMESTAMP, PrestoDataType.INT), + + // Returns the UNIX timestamp unixtime as a timestamp with time zone using string for the time zone. + FROM_UNIXTIME_TIMEZONE("from_unixtime", PrestoDataType.TIMESTAMP_WITH_TIME_ZONE, PrestoDataType.INT, + PrestoDataType.VARCHAR) { + @Override + public boolean shouldPreserveOrderOfArguments() { + return true; + } + }, + + // Returns the UNIX timestamp unixtime as a timestamp with time zone using hours and minutes for the time zone + // offset. + FROM_UNIXTIME_HOURS_MINUTES("from_unixtime", PrestoDataType.TIMESTAMP_WITH_TIME_ZONE, PrestoDataType.INT, + PrestoDataType.INT) { + @Override + public boolean shouldPreserveOrderOfArguments() { + return true; + } + }, + + // Returns the current time as of the start of the query. -> time + LOCALTIME("localtime", PrestoDataType.TIME), + + // Returns the current timestamp as of the start of the query. -> timestamp + LOCALTIMESTAMP("localtimestamp", PrestoDataType.TIMESTAMP), + + // This is an alias for current_timestamp. → timestamp with time zone# + NOW("now", PrestoDataType.TIMESTAMP_WITH_TIME_ZONE), + + // Formats x as an ISO 8601 string. x can be date, timestamp, or timestamp with time zone. → varchar# + TO_ISO8601("to_iso8601", PrestoDataType.VARCHAR, PrestoDataType.DATE, PrestoDataType.TIMESTAMP, + PrestoDataType.TIMESTAMP_WITH_TIME_ZONE), + + // Returns the day-to-second interval as milliseconds. → bigint# + TO_MILLISECONDS("to_milliseconds", PrestoDataType.INT, PrestoDataType.INTERVAL_DAY_TO_SECOND), + TO_MILLISECONDS_2("to_milliseconds", PrestoDataType.INT, PrestoDataType.INTERVAL_YEAR_TO_MONTH), + + // Returns timestamp as a UNIX timestamp. → double# + TO_UNIXTIME("to_unixtime", PrestoDataType.FLOAT, PrestoDataType.TIMESTAMP), + TO_UNIXTIME_2("to_unixtime", PrestoDataType.FLOAT, PrestoDataType.TIMESTAMP_WITH_TIME_ZONE), + + // The following SQL-standard functions do not use parenthesis: + CURRENT_DATE_NA("current_date", PrestoDataType.DATE) { + @Override + public boolean isStandardFunction() { + return false; + } + }, + + CURRENT_TIME_NA("current_time", PrestoDataType.TIME) { + @Override + public boolean isStandardFunction() { + return false; + } + }, + + CURRENT_TIMESTAMP_NA("current_timestamp", PrestoDataType.TIMESTAMP) { + @Override + public boolean isStandardFunction() { + return false; + } + }, + + LOCALTIME_NA("localtime", PrestoDataType.TIME) { + @Override + public boolean isStandardFunction() { + return false; + } + }, + + LOCALTIMESTAMP_NA("localtimestamp", PrestoDataType.TIMESTAMP) { + @Override + public boolean isStandardFunction() { + return false; + } + }, + + // Truncation Function + // date_trunc(unit, x) → [same as input] + DATE_TRUNC_1("date_trunc", PrestoDataType.TIMESTAMP, PrestoDataType.VARCHAR, PrestoDataType.TIMESTAMP), + DATE_TRUNC_2("date_trunc", PrestoDataType.TIMESTAMP_WITH_TIME_ZONE, PrestoDataType.VARCHAR, + PrestoDataType.TIMESTAMP_WITH_TIME_ZONE), + DATE_TRUNC_3("date_trunc", PrestoDataType.DATE, PrestoDataType.VARCHAR, PrestoDataType.DATE), + DATE_TRUNC_4("date_trunc", PrestoDataType.TIME, PrestoDataType.VARCHAR, PrestoDataType.TIME); + + /* + * + * Interval Functions# The functions in this section support the following interval units: + * + * Unit + * + * Description + * + * millisecond + * + * Milliseconds + * + * second + * + * Seconds + * + * minute + * + * Minutes + * + * hour + * + * Hours + * + * day + * + * Days + * + * week + * + * Weeks + * + * month + * + * Months + * + * quarter + * + * Quarters of a year + * + * year + * + * Years + * + * date_add(unit, value, timestamp) → [same as input]# Adds an interval value of type unit to timestamp. Subtraction + * can be performed by using a negative value. + * + * date_diff(unit, timestamp1, timestamp2) → bigint# Returns timestamp2 - timestamp1 expressed in terms of unit. + * + * Duration Function# The parse_duration function supports the following units: + * + * Unit + * + * Description + * + * ns + * + * Nanoseconds + * + * us + * + * Microseconds + * + * ms + * + * Milliseconds + * + * s + * + * Seconds + * + * m + * + * Minutes + * + * h + * + * Hours + * + * d + * + * Days + * + * parse_duration(string) → interval# Parses string of format value unit into an interval, where value is fractional + * number of unit values: + * + * SELECT parse_duration('42.8ms'); -- 0 00:00:00.043 SELECT parse_duration('3.81 d'); -- 3 19:26:24.000 SELECT + * parse_duration('5m'); -- 0 00:05:00.000 MySQL Date Functions# The functions in this section use a format string + * that is compatible with the MySQL date_parse and str_to_date functions. The following table, based on the MySQL + * manual, describes the format specifiers: + * + * Specifier + * + * Description + * + * %a + * + * Abbreviated weekday name (Sun .. Sat) + * + * %b + * + * Abbreviated month name (Jan .. Dec) + * + * %c + * + * Month, numeric (1 .. 12) 4 + * + * %D + * + * Day of the month with English suffix (0th, 1st, 2nd, 3rd, …) + * + * %d + * + * Day of the month, numeric (01 .. 31) 4 + * + * %e + * + * Day of the month, numeric (1 .. 31) 4 + * + * %f + * + * Fraction of second (6 digits for printing: 000000 .. 999000; 1 - 9 digits for parsing: 0 .. 999999999) 1 + * + * %H + * + * Hour (00 .. 23) + * + * %h + * + * Hour (01 .. 12) + * + * %I + * + * Hour (01 .. 12) + * + * %i + * + * Minutes, numeric (00 .. 59) + * + * %j + * + * Day of year (001 .. 366) + * + * %k + * + * Hour (0 .. 23) + * + * %l + * + * Hour (1 .. 12) + * + * %M + * + * Month name (January .. December) + * + * %m + * + * Month, numeric (01 .. 12) 4 + * + * %p + * + * AM or PM + * + * %r + * + * Time, 12-hour (hh:mm:ss followed by AM or PM) + * + * %S + * + * Seconds (00 .. 59) + * + * %s + * + * Seconds (00 .. 59) + * + * %T + * + * Time, 24-hour (hh:mm:ss) + * + * %U + * + * Week (00 .. 53), where Sunday is the first day of the week + * + * %u + * + * Week (00 .. 53), where Monday is the first day of the week + * + * %V + * + * Week (01 .. 53), where Sunday is the first day of the week; used with %X + * + * %v + * + * Week (01 .. 53), where Monday is the first day of the week; used with %x + * + * %W + * + * Weekday name (Sunday .. Saturday) + * + * %w + * + * Day of the week (0 .. 6), where Sunday is the first day of the week 3 + * + * %X + * + * Year for the week where Sunday is the first day of the week, numeric, four digits; used with %V + * + * %x + * + * Year for the week, where Monday is the first day of the week, numeric, four digits; used with %v + * + * %Y + * + * Year, numeric, four digits + * + * %y + * + * Year, numeric (two digits) 2 + * + * %% + * + * A literal % character + * + * %x + * + * x, for any x not listed above + * + * 1 Timestamp is truncated to milliseconds. + * + * 2 When parsing, two-digit year format assumes range 1970 ... 2069, so “70” will result in year 1970 but “69” will + * produce 2069. + * + * 3 This specifier is not supported yet. Consider using day_of_week() (it uses 1-7 instead of 0-6). + * + * 4(1,2,3,4) This specifier does not support 0 as a month or day. + * + * Warning + * + * The following specifiers are not currently supported: %D %U %u %V %w %X + * + * date_format(timestamp, format) → varchar# Formats timestamp as a string using format. + * + * date_parse(string, format) → timestamp# Parses string into a timestamp using format. + * + * Java Date Functions# The functions in this section use a format string that is compatible with JodaTime’s + * DateTimeFormat pattern format. + * + * format_datetime(timestamp, format) → varchar# Formats timestamp as a string using format. + * + * parse_datetime(string, format) → timestamp with time zone# Parses string into a timestamp with time zone using + * format. + * + * Extraction Function# The extract function supports the following fields: + * + * Field + * + * Description + * + * YEAR + * + * year() + * + * QUARTER + * + * quarter() + * + * MONTH + * + * month() + * + * WEEK + * + * week() + * + * DAY + * + * day() + * + * DAY_OF_MONTH + * + * day() + * + * DAY_OF_WEEK + * + * day_of_week() + * + * DOW + * + * day_of_week() + * + * DAY_OF_YEAR + * + * day_of_year() + * + * DOY + * + * day_of_year() + * + * YEAR_OF_WEEK + * + * year_of_week() + * + * YOW + * + * year_of_week() + * + * HOUR + * + * hour() + * + * MINUTE + * + * minute() + * + * SECOND + * + * second() + * + * TIMEZONE_HOUR + * + * timezone_hour() + * + * TIMEZONE_MINUTE + * + * timezone_minute() + * + * The types supported by the extract function vary depending on the field to be extracted. Most fields support all + * date and time types. + * + * extract(field FROM x) → bigint# Returns field from x. + * + * Note + * + * This SQL-standard function uses special syntax for specifying the arguments. + * + * Convenience Extraction Functions# day(x) → bigint# Returns the day of the month from x. + * + * day_of_month(x) → bigint# This is an alias for day(). + * + * day_of_week(x) → bigint# Returns the ISO day of the week from x. The value ranges from 1 (Monday) to 7 (Sunday). + * + * day_of_year(x) → bigint# Returns the day of the year from x. The value ranges from 1 to 366. + * + * dow(x) → bigint# This is an alias for day_of_week(). + * + * doy(x) → bigint# This is an alias for day_of_year(). + * + * hour(x) → bigint# Returns the hour of the day from x. The value ranges from 0 to 23. + * + * millisecond(x) → bigint# Returns the millisecond of the second from x. + * + * minute(x) → bigint# Returns the minute of the hour from x. + * + * month(x) → bigint# Returns the month of the year from x. + * + * quarter(x) → bigint# Returns the quarter of the year from x. The value ranges from 1 to 4. + * + * second(x) → bigint# Returns the second of the minute from x. + * + * timezone_hour(timestamp) → bigint# Returns the hour of the time zone offset from timestamp. + * + * timezone_minute(timestamp) → bigint# Returns the minute of the time zone offset from timestamp. + * + * week(x) → bigint# Returns the ISO week of the year from x. The value ranges from 1 to 53. + * + * week_of_year(x) → bigint# This is an alias for week(). + * + * year(x) → bigint# Returns the year from x. + * + * year_of_week(x) → bigint# Returns the year of the ISO week from x. + * + * yow(x) → bigint# This is an alias for year_of_week(). + * + * + * + */ + + private final PrestoDataType returnType; + private final PrestoDataType[] argumentTypes; + private final String functionName; + + PrestoDateFunction(String functionName, PrestoDataType returnType, PrestoDataType... argumentTypes) { + this.functionName = functionName; + this.returnType = returnType; + this.argumentTypes = argumentTypes.clone(); + } + + @Override + public String getFunctionName() { + return functionName; + } + + @Override + public boolean isCompatibleWithReturnType(PrestoCompositeDataType returnType) { + return this.returnType == returnType.getPrimitiveDataType(); + } + + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + return argumentTypes.clone(); + } + +} diff --git a/src/sqlancer/presto/ast/PrestoDefaultFunction.java b/src/sqlancer/presto/ast/PrestoDefaultFunction.java new file mode 100644 index 000000000..99b7f7fde --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoDefaultFunction.java @@ -0,0 +1,232 @@ +package sqlancer.presto.ast; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import sqlancer.Randomly; +import sqlancer.presto.PrestoSchema; +import sqlancer.presto.PrestoSchema.PrestoCompositeDataType; +import sqlancer.presto.PrestoSchema.PrestoDataType; +import sqlancer.presto.gen.PrestoTypedExpressionGenerator; + +public enum PrestoDefaultFunction implements PrestoFunction { + + // Conditional functions + IF_TRUE("if", null) { + @Override + public boolean isCompatibleWithReturnType(PrestoSchema.PrestoCompositeDataType returnType) { + return true; + } + + @Override + public PrestoSchema.PrestoDataType[] getArgumentTypes(PrestoSchema.PrestoCompositeDataType returnType) { + return new PrestoSchema.PrestoDataType[] { PrestoSchema.PrestoDataType.BOOLEAN, + returnType.getPrimitiveDataType() }; + } + }, + + IF_TRUE_FALSE("if", null) { + @Override + public boolean isCompatibleWithReturnType(PrestoSchema.PrestoCompositeDataType returnType) { + return true; + } + + @Override + public PrestoSchema.PrestoDataType[] getArgumentTypes(PrestoSchema.PrestoCompositeDataType returnType) { + return new PrestoSchema.PrestoDataType[] { PrestoSchema.PrestoDataType.BOOLEAN, + returnType.getPrimitiveDataType(), returnType.getPrimitiveDataType() }; + } + }, + + NULLIF("nullif", null) { + @Override + public boolean isCompatibleWithReturnType(PrestoSchema.PrestoCompositeDataType returnType) { + return true; + } + + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + return new PrestoDataType[] { returnType.getPrimitiveDataType(), returnType.getPrimitiveDataType() }; + } + }, + + COALESCE("coalesce", null) { + @Override + public boolean isCompatibleWithReturnType(PrestoSchema.PrestoCompositeDataType returnType) { + return true; + } + + @Override + public int getNumberOfArguments() { + return UNLIMITED_NO_OF_ARGUMENTS; + } + + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + ArrayList prestoDataTypes = new ArrayList<>(); + long no = Randomly.getNotCachedInteger(2, 10); + for (int i = 0; i < no; i++) { + prestoDataTypes.add(returnType.getPrimitiveDataType()); + } + return prestoDataTypes.toArray(new PrestoDataType[0]); + } + + @Override + public List getArgumentsForReturnType(PrestoTypedExpressionGenerator gen, int depth, + PrestoDataType[] argumentTypes, PrestoCompositeDataType returnType) { + return super.getArgumentsForReturnType(gen, depth, argumentTypes, returnType); + } + }, + + // comparison + + // Returns the largest of the provided values. → [same as input] + GREATEST("greatest", null) { + @Override + public boolean isCompatibleWithReturnType(PrestoSchema.PrestoCompositeDataType returnType) { + return PrestoDataType.getOrderableTypes().contains(returnType.getPrimitiveDataType()); + } + + @Override + public int getNumberOfArguments() { + return UNLIMITED_NO_OF_ARGUMENTS; + } + + @Override + public PrestoSchema.PrestoDataType[] getArgumentTypes(PrestoSchema.PrestoCompositeDataType returnType) { + return new PrestoSchema.PrestoDataType[] { returnType.getPrimitiveDataType() }; + } + }, + // Returns the smallest of the provided values. → [same as input] + LEAST("least", null) { + @Override + public boolean isCompatibleWithReturnType(PrestoSchema.PrestoCompositeDataType returnType) { + return PrestoDataType.getOrderableTypes().contains(returnType.getPrimitiveDataType()); + } + + @Override + public int getNumberOfArguments() { + return UNLIMITED_NO_OF_ARGUMENTS; + } + + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + ArrayList prestoDataTypes = new ArrayList<>(); + long no = Randomly.getNotCachedInteger(2, 10); + for (int i = 0; i < no; i++) { + prestoDataTypes.add(returnType.getPrimitiveDataType()); + } + return prestoDataTypes.toArray(new PrestoDataType[0]); + } + }; + + private static final int UNLIMITED_NO_OF_ARGUMENTS = -1; + private final PrestoDataType returnType; + private final PrestoDataType[] argumentTypes; + private final String functionName; + + PrestoDefaultFunction(String functionName, PrestoDataType returnType) { + this.functionName = functionName; + this.returnType = returnType; + this.argumentTypes = new PrestoDataType[0]; + } + + PrestoDefaultFunction(PrestoDataType returnType) { + this.returnType = returnType; + this.argumentTypes = new PrestoDataType[0]; + this.functionName = toString(); + } + + PrestoDefaultFunction(PrestoDataType returnType, PrestoDataType... argumentTypes) { + this.returnType = returnType; + this.argumentTypes = argumentTypes.clone(); + this.functionName = toString(); + } + + PrestoDefaultFunction(String functionName, PrestoDataType returnType, PrestoDataType... argumentTypes) { + this.functionName = functionName; + this.returnType = returnType; + this.argumentTypes = argumentTypes.clone(); + } + + public static List getFunctionsCompatibleWith(PrestoCompositeDataType returnType) { + return Stream.of(values()).filter(f -> f.isCompatibleWithReturnType(returnType)).collect(Collectors.toList()); + } + + @Override + public String getFunctionName() { + return functionName; + } + + @Override + public int getNumberOfArguments() { + return argumentTypes == null ? 0 : argumentTypes.length; + } + + @Override + public boolean isCompatibleWithReturnType(PrestoCompositeDataType returnType) { + return this.returnType == returnType.getPrimitiveDataType(); + } + + @Override + public PrestoDataType[] getArgumentTypes(PrestoCompositeDataType returnType) { + return argumentTypes.clone(); + } + + @Override + public List getArgumentsForReturnType(PrestoTypedExpressionGenerator gen, int depth, + PrestoDataType[] argumentTypes, PrestoCompositeDataType returnType) { + List arguments = new ArrayList<>(); + + // This is a workaround based on the assumption that array types should refer to the same element type. + PrestoCompositeDataType savedArrayType = null; + if (returnType.getPrimitiveDataType() == PrestoDataType.ARRAY) { + savedArrayType = returnType; + } + + if (getNumberOfArguments() == UNLIMITED_NO_OF_ARGUMENTS) { + PrestoDataType dataType = getArgumentTypes(returnType)[0]; + // TODO: consider upper + long no = Randomly.getNotCachedInteger(2, 10); + for (int i = 0; i < no; i++) { + PrestoCompositeDataType type; + + if (dataType == PrestoDataType.ARRAY) { + if (savedArrayType == null) { + savedArrayType = dataType.get(); + } + type = savedArrayType; + } else { + type = PrestoCompositeDataType.fromDataType(dataType); + } + arguments.add(gen.generateExpression(type, depth + 1)); + } + } else { + for (PrestoDataType arg : argumentTypes) { + PrestoCompositeDataType type; + if (arg == PrestoDataType.ARRAY) { + if (savedArrayType == null) { + savedArrayType = arg.get(); + } + type = savedArrayType; + } else { + type = PrestoCompositeDataType.fromDataType(arg); + } + arguments.add(gen.generateExpression(type, depth + 1)); + + } + } + return arguments; + } + + @Override + public String toString() { + if (functionName != null) { + return functionName; + } + return super.toString(); + } + +} diff --git a/src/sqlancer/presto/ast/PrestoExpression.java b/src/sqlancer/presto/ast/PrestoExpression.java new file mode 100644 index 000000000..52e0ed784 --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoExpression.java @@ -0,0 +1,8 @@ +package sqlancer.presto.ast; + +import sqlancer.common.ast.newast.Expression; +import sqlancer.presto.PrestoSchema.PrestoColumn; + +public interface PrestoExpression extends Expression { + +} diff --git a/src/sqlancer/presto/ast/PrestoFunction.java b/src/sqlancer/presto/ast/PrestoFunction.java new file mode 100644 index 000000000..fc7ef5b79 --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoFunction.java @@ -0,0 +1,126 @@ +package sqlancer.presto.ast; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.presto.PrestoSchema; +import sqlancer.presto.gen.PrestoTypedExpressionGenerator; + +public interface PrestoFunction extends PrestoExpression { + + String getFunctionName(); + + boolean isCompatibleWithReturnType(PrestoSchema.PrestoCompositeDataType returnType); + + PrestoSchema.PrestoDataType[] getArgumentTypes(PrestoSchema.PrestoCompositeDataType returnType); + + default List getArgumentsForReturnType(PrestoTypedExpressionGenerator gen, int depth, + PrestoSchema.PrestoDataType[] argumentTypes, PrestoSchema.PrestoCompositeDataType returnType) { + + List arguments = new ArrayList<>(); + + // This is a workaround based on the assumption that array types should refer to + // the same element type. + PrestoSchema.PrestoCompositeDataType savedArrayType = null; + if (returnType.getPrimitiveDataType() == PrestoSchema.PrestoDataType.ARRAY) { + savedArrayType = returnType; + } + // -1 - unlimited number of arguments + if (getNumberOfArguments() == -1) { + PrestoSchema.PrestoDataType dataType = argumentTypes[0]; + // TODO: consider upper + long no = Randomly.getNotCachedInteger(2, 10); + for (int i = 0; i < no; i++) { + PrestoSchema.PrestoCompositeDataType type; + + if (dataType == PrestoSchema.PrestoDataType.ARRAY) { + if (savedArrayType == null) { + savedArrayType = dataType.get(); + } + type = savedArrayType; + } else { + type = PrestoSchema.PrestoCompositeDataType.fromDataType(dataType); + } + arguments.add(gen.generateExpression(type, depth + 1)); + } + } else { + for (PrestoSchema.PrestoDataType arg : argumentTypes) { + PrestoSchema.PrestoCompositeDataType dataType; + if (arg == PrestoSchema.PrestoDataType.ARRAY) { + if (savedArrayType == null) { + savedArrayType = arg.get(); + } + dataType = savedArrayType; + } else { + dataType = PrestoSchema.PrestoCompositeDataType.fromDataType(arg); + } + PrestoExpression expression = gen.generateExpression(dataType, depth + 1); + arguments.add(expression); + } + } + return arguments; + } + + default List getArgumentsForReturnType(PrestoTypedExpressionGenerator gen, int depth, + PrestoSchema.PrestoCompositeDataType returnType, boolean orderable) { + + List arguments = new ArrayList<>(); + + // This is a workaround based on the assumption that array types should refer to + // the same element type. + PrestoSchema.PrestoCompositeDataType savedArrayType = null; + if (returnType.getPrimitiveDataType() == PrestoSchema.PrestoDataType.ARRAY) { + savedArrayType = returnType; + } + if (getNumberOfArguments() == -1) { + PrestoSchema.PrestoDataType dataType = getArgumentTypes(returnType)[0]; + // TODO: consider upper + long no = Randomly.getNotCachedInteger(2, 10); + for (int i = 0; i < no; i++) { + PrestoSchema.PrestoCompositeDataType compositeDataType; + if (dataType == PrestoSchema.PrestoDataType.ARRAY) { + if (savedArrayType == null) { + savedArrayType = dataType.get(); + } + compositeDataType = savedArrayType; + } else { + compositeDataType = PrestoSchema.PrestoCompositeDataType.fromDataType(dataType); + } + arguments.add(gen.generateExpression(compositeDataType, depth + 1)); + } + } else { + for (PrestoSchema.PrestoDataType dataType : getArgumentTypes(returnType)) { + PrestoSchema.PrestoCompositeDataType compositeDataType; + if (dataType == PrestoSchema.PrestoDataType.ARRAY) { + if (savedArrayType == null) { + PrestoSchema.PrestoCompositeDataType arrayType; + do { + arrayType = dataType.get(); + } while (!arrayType.getElementType().isOrderable()); + savedArrayType = arrayType; + } + compositeDataType = savedArrayType; + } else { + compositeDataType = PrestoSchema.PrestoCompositeDataType.fromDataType(dataType); + } + PrestoExpression expression = gen.generateExpression(compositeDataType, depth + 1); + arguments.add(expression); + } + } + return arguments; + } + + default int getNumberOfArguments() { + return getArgumentTypes(null).length; + } + + default boolean shouldPreserveOrderOfArguments() { + return false; + } + + default boolean isStandardFunction() { + return true; + } + +} diff --git a/src/sqlancer/presto/ast/PrestoFunctionNode.java b/src/sqlancer/presto/ast/PrestoFunctionNode.java new file mode 100644 index 000000000..14409a825 --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoFunctionNode.java @@ -0,0 +1,11 @@ +package sqlancer.presto.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewFunctionNode; + +public class PrestoFunctionNode extends NewFunctionNode implements PrestoExpression { + public PrestoFunctionNode(List args, F func) { + super(args, func); + } +} diff --git a/src/sqlancer/presto/ast/PrestoFunctionWithoutParenthesis.java b/src/sqlancer/presto/ast/PrestoFunctionWithoutParenthesis.java new file mode 100644 index 000000000..88007053d --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoFunctionWithoutParenthesis.java @@ -0,0 +1,23 @@ +package sqlancer.presto.ast; + +import sqlancer.presto.PrestoSchema; + +public class PrestoFunctionWithoutParenthesis implements PrestoExpression { + + private final PrestoSchema.PrestoCompositeDataType type; + private final String expr; + + public PrestoFunctionWithoutParenthesis(String expr, PrestoSchema.PrestoCompositeDataType type) { + this.expr = expr; + this.type = type; + } + + public String getExpr() { + return expr; + } + + public PrestoSchema.PrestoCompositeDataType getType() { + return type; + } + +} diff --git a/src/sqlancer/presto/ast/PrestoInOperation.java b/src/sqlancer/presto/ast/PrestoInOperation.java new file mode 100644 index 000000000..95edd7fc3 --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoInOperation.java @@ -0,0 +1,11 @@ +package sqlancer.presto.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewInOperatorNode; + +public class PrestoInOperation extends NewInOperatorNode implements PrestoExpression { + public PrestoInOperation(PrestoExpression left, List right, boolean isNegated) { + super(left, right, isNegated); + } +} diff --git a/src/sqlancer/presto/ast/PrestoJoin.java b/src/sqlancer/presto/ast/PrestoJoin.java new file mode 100644 index 000000000..3fe82fc37 --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoJoin.java @@ -0,0 +1,120 @@ +package sqlancer.presto.ast; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.ast.newast.Join; +import sqlancer.presto.PrestoGlobalState; +import sqlancer.presto.PrestoSchema; +import sqlancer.presto.PrestoSchema.PrestoColumn; +import sqlancer.presto.PrestoSchema.PrestoTable; +import sqlancer.presto.gen.PrestoTypedExpressionGenerator; + +public class PrestoJoin implements PrestoExpression, Join { + + private final PrestoTableReference leftTable; + private final PrestoTableReference rightTable; + private final JoinType joinType; + private PrestoExpression onCondition; + private OuterType outerType; + + public PrestoJoin(PrestoTableReference leftTable, PrestoTableReference rightTable, JoinType joinType, + PrestoExpression whereCondition) { + this.leftTable = leftTable; + this.rightTable = rightTable; + this.joinType = joinType; + this.onCondition = whereCondition; + } + + public static List getJoins(List tableList, PrestoGlobalState globalState) { + List joinExpressions = new ArrayList<>(); + while (tableList.size() >= 2 && Randomly.getBooleanWithRatherLowProbability()) { + PrestoTableReference leftTable = tableList.remove(0); + PrestoTableReference rightTable = tableList.remove(0); + List columns = new ArrayList<>(leftTable.getTable().getColumns()); + columns.addAll(rightTable.getTable().getColumns()); + PrestoTypedExpressionGenerator joinGen = new PrestoTypedExpressionGenerator(globalState) + .setColumns(columns); + switch (JoinType.getRandom()) { + case INNER: + joinExpressions.add(PrestoJoin.createInnerJoin(leftTable, rightTable, joinGen.generateExpression( + PrestoSchema.PrestoCompositeDataType.fromDataType(PrestoSchema.PrestoDataType.BOOLEAN)))); + break; + case LEFT: + joinExpressions.add(PrestoJoin.createLeftOuterJoin(leftTable, rightTable, joinGen.generateExpression( + PrestoSchema.PrestoCompositeDataType.fromDataType(PrestoSchema.PrestoDataType.BOOLEAN)))); + break; + case RIGHT: + joinExpressions.add(PrestoJoin.createRightOuterJoin(leftTable, rightTable, joinGen.generateExpression( + PrestoSchema.PrestoCompositeDataType.fromDataType(PrestoSchema.PrestoDataType.BOOLEAN)))); + break; + default: + throw new AssertionError(); + } + } + return joinExpressions; + } + + public static PrestoJoin createRightOuterJoin(PrestoTableReference left, PrestoTableReference right, + PrestoExpression predicate) { + return new PrestoJoin(left, right, JoinType.RIGHT, predicate); + } + + public static PrestoJoin createLeftOuterJoin(PrestoTableReference left, PrestoTableReference right, + PrestoExpression predicate) { + return new PrestoJoin(left, right, JoinType.LEFT, predicate); + } + + public static PrestoJoin createInnerJoin(PrestoTableReference left, PrestoTableReference right, + PrestoExpression predicate) { + return new PrestoJoin(left, right, JoinType.INNER, predicate); + } + + public PrestoTableReference getLeftTable() { + return leftTable; + } + + public PrestoTableReference getRightTable() { + return rightTable; + } + + public JoinType getJoinType() { + return joinType; + } + + public PrestoExpression getOnCondition() { + return onCondition; + } + + public OuterType getOuterType() { + return outerType; + } + + @SuppressWarnings("unused") + private void setOuterType(OuterType outerType) { + this.outerType = outerType; + } + + public enum JoinType { + INNER, LEFT, RIGHT; + + public static JoinType getRandom() { + return Randomly.fromOptions(values()); + } + } + + public enum OuterType { + FULL, LEFT, RIGHT; + + public static OuterType getRandom() { + return Randomly.fromOptions(values()); + } + } + + @Override + public void setOnClause(PrestoExpression onClause) { + onCondition = onClause; + } + +} diff --git a/src/sqlancer/presto/ast/PrestoMultiValuedComparison.java b/src/sqlancer/presto/ast/PrestoMultiValuedComparison.java new file mode 100644 index 000000000..20ceaee79 --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoMultiValuedComparison.java @@ -0,0 +1,37 @@ +package sqlancer.presto.ast; + +import java.util.ArrayList; +import java.util.List; + +public class PrestoMultiValuedComparison implements PrestoExpression { + + private final PrestoExpression left; + private final List right; + private final PrestoMultiValuedComparisonType type; + private final PrestoMultiValuedComparisonOperator op; + + public PrestoMultiValuedComparison(PrestoExpression left, List right, + PrestoMultiValuedComparisonType type, PrestoMultiValuedComparisonOperator op) { + this.left = left; + this.right = new ArrayList<>(right); + this.type = type; + this.op = op; + } + + public PrestoExpression getLeft() { + return left; + } + + public PrestoMultiValuedComparisonOperator getOp() { + return op; + } + + public List getRight() { + return new ArrayList<>(right); + } + + public PrestoMultiValuedComparisonType getType() { + return type; + } + +} diff --git a/src/sqlancer/presto/ast/PrestoMultiValuedComparisonOperator.java b/src/sqlancer/presto/ast/PrestoMultiValuedComparisonOperator.java new file mode 100644 index 000000000..829f554f4 --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoMultiValuedComparisonOperator.java @@ -0,0 +1,43 @@ +package sqlancer.presto.ast; + +import sqlancer.Randomly; +import sqlancer.presto.PrestoSchema; + +public enum PrestoMultiValuedComparisonOperator { + EQUALS("="), NOT_EQUALS("<>"), NOT_EQUALS_ALT("!="), GREATER(">"), GREATER_EQUALS(">="), SMALLER("<"), + SMALLER_EQUALS("<="); + + private final String stringRepresentation; + + PrestoMultiValuedComparisonOperator(String stringRepresentation) { + this.stringRepresentation = stringRepresentation; + } + + public static PrestoMultiValuedComparisonOperator getRandom() { + return Randomly.fromOptions(values()); + } + + public static PrestoMultiValuedComparisonOperator getRandomForType(PrestoSchema.PrestoCompositeDataType type) { + PrestoSchema.PrestoDataType dataType = type.getPrimitiveDataType(); + + switch (dataType) { + case BOOLEAN: + case INT: + case FLOAT: + case DECIMAL: + case DATE: + case TIME: + case TIMESTAMP: + case TIME_WITH_TIME_ZONE: + case TIMESTAMP_WITH_TIME_ZONE: + return getRandom(); + default: + return Randomly.fromOptions(EQUALS, NOT_EQUALS, NOT_EQUALS_ALT); + } + } + + public String getStringRepresentation() { + return stringRepresentation; + } + +} diff --git a/src/sqlancer/presto/ast/PrestoMultiValuedComparisonType.java b/src/sqlancer/presto/ast/PrestoMultiValuedComparisonType.java new file mode 100644 index 000000000..d1ecdd382 --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoMultiValuedComparisonType.java @@ -0,0 +1,11 @@ +package sqlancer.presto.ast; + +import sqlancer.Randomly; + +public enum PrestoMultiValuedComparisonType { + ANY, SOME, ALL; + + public static PrestoMultiValuedComparisonType getRandom() { + return Randomly.fromOptions(values()); + } +} diff --git a/src/sqlancer/presto/ast/PrestoPostfixText.java b/src/sqlancer/presto/ast/PrestoPostfixText.java new file mode 100644 index 000000000..de99fede4 --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoPostfixText.java @@ -0,0 +1,9 @@ +package sqlancer.presto.ast; + +import sqlancer.common.ast.newast.NewPostfixTextNode; + +public class PrestoPostfixText extends NewPostfixTextNode implements PrestoExpression { + public PrestoPostfixText(PrestoExpression expr, String text) { + super(expr, text); + } +} diff --git a/src/sqlancer/presto/ast/PrestoQuantifiedComparison.java b/src/sqlancer/presto/ast/PrestoQuantifiedComparison.java new file mode 100644 index 000000000..3cbccc7b2 --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoQuantifiedComparison.java @@ -0,0 +1,34 @@ +package sqlancer.presto.ast; + +public class PrestoQuantifiedComparison implements PrestoExpression { + + private final PrestoExpression left; + private final PrestoSelect right; + private final PrestoMultiValuedComparisonType type; + private final PrestoMultiValuedComparisonOperator op; + + public PrestoQuantifiedComparison(PrestoExpression left, PrestoSelect right, PrestoMultiValuedComparisonType type, + PrestoMultiValuedComparisonOperator op) { + this.left = left; + this.right = right; + this.type = type; + this.op = op; + } + + public PrestoExpression getLeft() { + return left; + } + + public PrestoMultiValuedComparisonOperator getOp() { + return op; + } + + public PrestoExpression getRight() { + return right; + } + + public PrestoMultiValuedComparisonType getType() { + return type; + } + +} diff --git a/src/sqlancer/presto/ast/PrestoSelect.java b/src/sqlancer/presto/ast/PrestoSelect.java new file mode 100644 index 000000000..f1eb50186 --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoSelect.java @@ -0,0 +1,41 @@ +package sqlancer.presto.ast; + +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.common.ast.SelectBase; +import sqlancer.common.ast.newast.Select; +import sqlancer.presto.PrestoSchema.PrestoColumn; +import sqlancer.presto.PrestoSchema.PrestoTable; +import sqlancer.presto.PrestoToStringVisitor; + +public class PrestoSelect extends SelectBase + implements PrestoExpression, Select { + + private boolean isDistinct; + + public boolean isDistinct() { + return isDistinct; + } + + public void setDistinct(boolean isDistinct) { + this.isDistinct = isDistinct; + } + + @Override + public void setJoinClauses(List joinStatements) { + List expressions = joinStatements.stream().map(e -> (PrestoExpression) e) + .collect(Collectors.toList()); + setJoinList(expressions); + } + + @Override + public List getJoinClauses() { + return getJoinList().stream().map(e -> (PrestoJoin) e).collect(Collectors.toList()); + } + + @Override + public String asString() { + return PrestoToStringVisitor.asString(this); + } +} diff --git a/src/sqlancer/presto/ast/PrestoTableReference.java b/src/sqlancer/presto/ast/PrestoTableReference.java new file mode 100644 index 000000000..7552ecc64 --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoTableReference.java @@ -0,0 +1,12 @@ +package sqlancer.presto.ast; + +import sqlancer.common.ast.newast.TableReferenceNode; +import sqlancer.presto.PrestoSchema; + +public class PrestoTableReference extends TableReferenceNode + implements PrestoExpression { + + public PrestoTableReference(PrestoSchema.PrestoTable table) { + super(table); + } +} diff --git a/src/sqlancer/presto/ast/PrestoTernary.java b/src/sqlancer/presto/ast/PrestoTernary.java new file mode 100644 index 000000000..618daa0fc --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoTernary.java @@ -0,0 +1,10 @@ +package sqlancer.presto.ast; + +import sqlancer.common.ast.newast.NewTernaryNode; + +public class PrestoTernary extends NewTernaryNode implements PrestoExpression { + public PrestoTernary(PrestoExpression left, PrestoExpression middle, PrestoExpression right, String leftStr, + String rightStr) { + super(left, middle, right, leftStr, rightStr); + } +} diff --git a/src/sqlancer/presto/ast/PrestoUnaryPostfixOperation.java b/src/sqlancer/presto/ast/PrestoUnaryPostfixOperation.java new file mode 100644 index 000000000..99677d3f8 --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoUnaryPostfixOperation.java @@ -0,0 +1,52 @@ +package sqlancer.presto.ast; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.newast.NewUnaryPostfixOperatorNode; +import sqlancer.presto.PrestoSchema; + +public class PrestoUnaryPostfixOperation extends NewUnaryPostfixOperatorNode + implements PrestoExpression { + + public PrestoUnaryPostfixOperation(PrestoExpression expr, PrestoUnaryPostfixOperator op) { + super(expr, op); + } + + public PrestoExpression getExpression() { + return getExpr(); + } + + public enum PrestoUnaryPostfixOperator implements BinaryOperatorNode.Operator { + IS_NULL("IS NULL") { + @Override + public PrestoSchema.PrestoDataType[] getInputDataTypes() { + return PrestoSchema.PrestoDataType.values(); + } + }, + IS_NOT_NULL("IS NOT NULL") { + @Override + public PrestoSchema.PrestoDataType[] getInputDataTypes() { + return PrestoSchema.PrestoDataType.values(); + } + }; + + private final String textRepresentations; + + PrestoUnaryPostfixOperator(String text) { + this.textRepresentations = text; + } + + public static PrestoUnaryPostfixOperator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepresentations; + } + + public abstract PrestoSchema.PrestoDataType[] getInputDataTypes(); + + } + +} diff --git a/src/sqlancer/presto/ast/PrestoUnaryPrefixOperation.java b/src/sqlancer/presto/ast/PrestoUnaryPrefixOperation.java new file mode 100644 index 000000000..db7a08068 --- /dev/null +++ b/src/sqlancer/presto/ast/PrestoUnaryPrefixOperation.java @@ -0,0 +1,62 @@ +package sqlancer.presto.ast; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode; +import sqlancer.presto.PrestoSchema; + +public class PrestoUnaryPrefixOperation extends NewUnaryPrefixOperatorNode + implements PrestoExpression { + + public PrestoUnaryPrefixOperation(PrestoExpression expression, BinaryOperatorNode.Operator operation) { + super(expression, operation); + } + + public enum PrestoUnaryPrefixOperator implements BinaryOperatorNode.Operator { + NOT("NOT", PrestoSchema.PrestoDataType.BOOLEAN) { + @Override + public PrestoSchema.PrestoDataType getExpressionType() { + return PrestoSchema.PrestoDataType.BOOLEAN; + } + }, + + UNARY_PLUS("+", PrestoSchema.PrestoDataType.INT, PrestoSchema.PrestoDataType.FLOAT, + PrestoSchema.PrestoDataType.DECIMAL) { + @Override + public PrestoSchema.PrestoDataType getExpressionType() { + return PrestoSchema.PrestoDataType.INT; + } + }, + UNARY_MINUS("-", PrestoSchema.PrestoDataType.INT, PrestoSchema.PrestoDataType.FLOAT, + PrestoSchema.PrestoDataType.DECIMAL) { + @Override + public PrestoSchema.PrestoDataType getExpressionType() { + return PrestoSchema.PrestoDataType.INT; + } + }; + + private final String textRepresentation; + private final PrestoSchema.PrestoDataType[] dataTypes; + + PrestoUnaryPrefixOperator(String textRepresentation, PrestoSchema.PrestoDataType... dataTypes) { + this.textRepresentation = textRepresentation; + this.dataTypes = dataTypes.clone(); + } + + public PrestoSchema.PrestoDataType getRandomInputDataTypes() { + return Randomly.fromOptions(dataTypes); + } + + public abstract PrestoSchema.PrestoDataType getExpressionType(); + + @Override + public String getTextRepresentation() { + return this.textRepresentation; + } + + public PrestoSchema.PrestoDataType getExpressionType(PrestoSchema.PrestoDataType type) { + return type; + } + } + +} diff --git a/src/sqlancer/presto/gen/PrestoAlterTableGenerator.java b/src/sqlancer/presto/gen/PrestoAlterTableGenerator.java new file mode 100644 index 000000000..f3d5f9d55 --- /dev/null +++ b/src/sqlancer/presto/gen/PrestoAlterTableGenerator.java @@ -0,0 +1,67 @@ +package sqlancer.presto.gen; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.presto.PrestoGlobalState; +import sqlancer.presto.PrestoSchema.PrestoCompositeDataType; +import sqlancer.presto.PrestoSchema.PrestoTable; + +public final class PrestoAlterTableGenerator { + + private PrestoAlterTableGenerator() { + } + + public static SQLQueryAdapter getQuery(PrestoGlobalState globalState) { + ExpectedErrors errors = new ExpectedErrors(); + StringBuilder sb = new StringBuilder("ALTER TABLE "); + PrestoTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + // PrestoTypedExpressionGenerator gen = new + // PrestoTypedExpressionGenerator(globalState).setColumns(table.getColumns()); + sb.append(table.getName()); + sb.append(" "); + Action action = Randomly.fromOptions(Action.values()); + switch (action) { + case ADD_COLUMN: + sb.append("ADD COLUMN "); + String columnName = table.getFreeColumnName(); + sb.append(columnName); + sb.append(" "); + sb.append(PrestoCompositeDataType.getRandomWithoutNull()); + break; + case ALTER_COLUMN: + sb.append("ALTER COLUMN "); + sb.append(table.getRandomColumn().getName()); + sb.append(" SET DATA TYPE "); + sb.append(PrestoCompositeDataType.getRandomWithoutNull()); + // if (Randomly.getBoolean()) { + // sb.append(" USING "); + // PrestoErrors.addExpressionErrors(errors); + // sb.append(PrestoToStringVisitor.asString(gen.generateExpression())); + // } + errors.add("Cannot change the type of this column: an index depends on it!"); + errors.add("Cannot change the type of a column that has a UNIQUE or PRIMARY KEY constraint specified"); + errors.add("Unimplemented type for cast"); + errors.add("Conversion:"); + errors.add("Cannot change the type of a column that has a CHECK constraint specified"); + break; + case DROP_COLUMN: + sb.append("DROP COLUMN "); + sb.append(table.getRandomColumn().getName()); + errors.add("named in key does not exist"); // TODO + errors.add("Cannot drop this column:"); + errors.add("Cannot drop column: table only has one column remaining!"); + errors.add("because there is a CHECK constraint that depends on it"); + errors.add("because there is a UNIQUE constraint that depends on it"); + break; + default: + throw new AssertionError(action); + } + return new SQLQueryAdapter(sb.toString(), errors, true, false); + } + + enum Action { + ADD_COLUMN, ALTER_COLUMN, DROP_COLUMN + } + +} diff --git a/src/sqlancer/presto/gen/PrestoDeleteGenerator.java b/src/sqlancer/presto/gen/PrestoDeleteGenerator.java new file mode 100644 index 000000000..9f869c241 --- /dev/null +++ b/src/sqlancer/presto/gen/PrestoDeleteGenerator.java @@ -0,0 +1,32 @@ +package sqlancer.presto.gen; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.presto.PrestoErrors; +import sqlancer.presto.PrestoGlobalState; +import sqlancer.presto.PrestoSchema; +import sqlancer.presto.PrestoSchema.PrestoTable; +import sqlancer.presto.PrestoToStringVisitor; + +public final class PrestoDeleteGenerator { + + private PrestoDeleteGenerator() { + } + + public static SQLQueryAdapter generate(PrestoGlobalState globalState) { + StringBuilder sb = new StringBuilder("DELETE FROM "); + ExpectedErrors errors = new ExpectedErrors(); + PrestoTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + sb.append(table.getName()); + if (Randomly.getBoolean()) { + sb.append(" WHERE "); + sb.append(PrestoToStringVisitor + .asString(new PrestoTypedExpressionGenerator(globalState).setColumns(table.getColumns()) + .generateExpression(PrestoSchema.PrestoCompositeDataType.getRandomWithoutNull()))); + } + PrestoErrors.addExpressionErrors(errors); + return new SQLQueryAdapter(sb.toString(), errors, false, false); + } + +} diff --git a/src/sqlancer/presto/gen/PrestoIndexGenerator.java b/src/sqlancer/presto/gen/PrestoIndexGenerator.java new file mode 100644 index 000000000..c76283ee4 --- /dev/null +++ b/src/sqlancer/presto/gen/PrestoIndexGenerator.java @@ -0,0 +1,56 @@ +package sqlancer.presto.gen; + +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.presto.PrestoGlobalState; +import sqlancer.presto.PrestoSchema; +import sqlancer.presto.PrestoSchema.PrestoColumn; +import sqlancer.presto.PrestoSchema.PrestoTable; +import sqlancer.presto.PrestoToStringVisitor; +import sqlancer.presto.ast.PrestoExpression; + +public final class PrestoIndexGenerator { + + private PrestoIndexGenerator() { + } + + public static SQLQueryAdapter getQuery(PrestoGlobalState globalState) { + ExpectedErrors errors = new ExpectedErrors(); + StringBuilder sb = new StringBuilder(); + sb.append("CREATE "); + if (Randomly.getBoolean()) { + errors.add("Cant create unique index, table contains duplicate data on indexed column(s)"); + sb.append("UNIQUE "); + } + sb.append("INDEX "); + sb.append(Randomly.fromOptions("i0", "i1", "i2", "i3", "i4")); // cannot query this information + sb.append(" ON "); + PrestoTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + sb.append(table.getName()); + sb.append("("); + List columns = table.getRandomNonEmptyColumnSubset(); + for (int i = 0; i < columns.size(); i++) { + if (i != 0) { + sb.append(", "); + } + sb.append(columns.get(i).getName()); + sb.append(" "); + if (Randomly.getBooleanWithRatherLowProbability()) { + sb.append(Randomly.fromOptions("ASC", "DESC")); + } + } + sb.append(")"); + if (Randomly.getBoolean()) { + sb.append(" WHERE "); + PrestoExpression expr = new PrestoTypedExpressionGenerator(globalState).setColumns(table.getColumns()) + .generateExpression(PrestoSchema.PrestoCompositeDataType.getRandomWithoutNull()); + sb.append(PrestoToStringVisitor.asString(expr)); + } + errors.add("already exists!"); + return new SQLQueryAdapter(sb.toString(), errors, true, false); + } + +} diff --git a/src/sqlancer/presto/gen/PrestoInsertGenerator.java b/src/sqlancer/presto/gen/PrestoInsertGenerator.java new file mode 100644 index 000000000..072a22ae0 --- /dev/null +++ b/src/sqlancer/presto/gen/PrestoInsertGenerator.java @@ -0,0 +1,51 @@ +package sqlancer.presto.gen; + +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.common.gen.AbstractInsertGenerator; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.presto.PrestoErrors; +import sqlancer.presto.PrestoGlobalState; +import sqlancer.presto.PrestoSchema.PrestoColumn; +import sqlancer.presto.PrestoSchema.PrestoTable; +import sqlancer.presto.PrestoToStringVisitor; +import sqlancer.presto.ast.PrestoExpression; + +public class PrestoInsertGenerator extends AbstractInsertGenerator { + + private final PrestoGlobalState globalState; + + public PrestoInsertGenerator(PrestoGlobalState globalState) { + this.globalState = globalState; + } + + public static SQLQueryAdapter getQuery(PrestoGlobalState globalState) { + return new PrestoInsertGenerator(globalState).generate(); + } + + private SQLQueryAdapter generate() { + sb.append("INSERT INTO "); + PrestoTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + List columns = table.getRandomNonEmptyColumnSubset(); + sb.append(table.getName()); + sb.append("("); + sb.append(columns.stream().map(c -> c.getName()).collect(Collectors.joining(", "))); + sb.append(")"); + sb.append(" VALUES "); + insertColumns(columns); + ExpectedErrors errors = new ExpectedErrors(); + PrestoErrors.addInsertErrors(errors); + return new SQLQueryAdapter(sb.toString(), errors, false, false); + } + + @Override + protected void insertValue(PrestoColumn prestoColumn) { + PrestoExpression constant = new PrestoTypedExpressionGenerator(globalState) + .generateInsertConstant(prestoColumn.getType()); + sb.append(PrestoToStringVisitor.asString(constant)); + + } + +} diff --git a/src/sqlancer/presto/gen/PrestoRandomQuerySynthesizer.java b/src/sqlancer/presto/gen/PrestoRandomQuerySynthesizer.java new file mode 100644 index 000000000..5c3c0db82 --- /dev/null +++ b/src/sqlancer/presto/gen/PrestoRandomQuerySynthesizer.java @@ -0,0 +1,73 @@ +package sqlancer.presto.gen; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.presto.PrestoGlobalState; +import sqlancer.presto.PrestoSchema; +import sqlancer.presto.PrestoSchema.PrestoTable; +import sqlancer.presto.PrestoSchema.PrestoTables; +import sqlancer.presto.ast.PrestoConstant; +import sqlancer.presto.ast.PrestoExpression; +import sqlancer.presto.ast.PrestoJoin; +import sqlancer.presto.ast.PrestoSelect; +import sqlancer.presto.ast.PrestoTableReference; + +public final class PrestoRandomQuerySynthesizer { + + private PrestoRandomQuerySynthesizer() { + } + + public static PrestoSelect generateSelect(PrestoGlobalState globalState, int nrColumns) { + PrestoTables targetTables = globalState.getSchema().getRandomTableNonEmptyTables(); + PrestoTypedExpressionGenerator gen = new PrestoTypedExpressionGenerator(globalState) + .setColumns(targetTables.getColumns()); + PrestoSelect select = new PrestoSelect(); + // TODO: distinct + // select.setDistinct(Randomly.getBoolean()); + // boolean allowAggregates = Randomly.getBooleanWithSmallProbability(); + List columns = new ArrayList<>(); + for (int i = 0; i < nrColumns; i++) { + // if (allowAggregates && Randomly.getBoolean()) { + PrestoExpression expression = gen + .generateExpression(PrestoSchema.PrestoCompositeDataType.getRandomWithoutNull()); + columns.add(expression); + // } else { + // columns.add(gen()); + // } + } + select.setFetchColumns(columns); + List tables = targetTables.getTables(); + List tableList = tables.stream().map(t -> new PrestoTableReference(t)) + .collect(Collectors.toList()); + List joins = PrestoJoin.getJoins(tableList, globalState).stream() + .collect(Collectors.toList()); + select.setJoinList(new ArrayList<>(joins)); + select.setFromList(new ArrayList<>(tableList)); + if (Randomly.getBoolean()) { + select.setWhereClause(gen.generateExpression(PrestoSchema.PrestoCompositeDataType.getRandomWithoutNull())); + } + if (Randomly.getBoolean()) { + select.setOrderByClauses(gen.generateOrderBys()); + } + if (Randomly.getBoolean()) { + select.setGroupByExpressions(gen.generateExpressions(Randomly.smallNumber() + 1)); + } + + if (Randomly.getBoolean()) { + select.setLimitClause(PrestoConstant.createIntConstant(Randomly.getNotCachedInteger(0, Integer.MAX_VALUE))); + } + // if (Randomly.getBoolean()) { + // select.setOffsetClause( + // PrestoConstant.createIntConstant(Randomly.getNotCachedInteger(0, + // Integer.MAX_VALUE))); + // } + if (Randomly.getBoolean()) { + select.setHavingClause(gen.generateHavingClause()); + } + return select; + } + +} diff --git a/src/sqlancer/presto/gen/PrestoTableGenerator.java b/src/sqlancer/presto/gen/PrestoTableGenerator.java new file mode 100644 index 000000000..1d7df2ee6 --- /dev/null +++ b/src/sqlancer/presto/gen/PrestoTableGenerator.java @@ -0,0 +1,69 @@ +package sqlancer.presto.gen; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.presto.PrestoGlobalState; +import sqlancer.presto.PrestoSchema.PrestoColumn; +import sqlancer.presto.PrestoSchema.PrestoCompositeDataType; + +public class PrestoTableGenerator { + + private static List getNewColumns() { + List columns = new ArrayList<>(); + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + String columnName = String.format("c%d", i); + PrestoCompositeDataType columnType = PrestoCompositeDataType.getRandomWithoutNull(); + columns.add(new PrestoColumn(columnName, columnType, false, false)); + } + return columns; + } + + public SQLQueryAdapter getQuery(PrestoGlobalState globalState) { + ExpectedErrors errors = new ExpectedErrors(); + StringBuilder sb = new StringBuilder(); + String tableName = globalState.getSchema().getFreeTableName(); + sb.append("CREATE TABLE "); + String catalog = globalState.getDbmsSpecificOptions().catalog; + String schema = globalState.getDatabaseName(); + + sb.append(catalog).append("."); + sb.append(schema).append("."); + + sb.append(tableName); + sb.append("("); + List columns = getNewColumns(); + // TypedExpressionGenerator, PrestoColumn, PrestoCompositeDataType> + // typedExpressionGenerator = new PrestoTypedExpressionGenerator(globalState).setColumns(columns); + for (int i = 0; i < columns.size(); i++) { + if (i != 0) { + sb.append(", "); + } + PrestoColumn column = columns.get(i); + sb.append(column.getName()); + sb.append(" "); + sb.append(column.getType()); + // if (globalState.getDbmsSpecificOptions().testIndexes && Randomly.getBooleanWithRatherLowProbability()) { + // sb.append(" UNIQUE"); + // } + // if (globalState.getDbmsSpecificOptions().testNotNullConstraints + // && Randomly.getBooleanWithRatherLowProbability()) { + // sb.append(" NOT NULL"); + // } + } + // if (globalState.getDbmsSpecificOptions().testIndexes && Randomly.getBoolean()) { + // errors.add("Invalid type for index"); + // List primaryKeyColumns = Randomly.nonEmptySubset(columns); + // sb.append(", PRIMARY KEY("); + // sb.append(primaryKeyColumns.stream().map(c -> c.getName()).collect(Collectors.joining(", "))); + // sb.append(")"); + // } + sb.append(")"); + + return new SQLQueryAdapter(sb.toString(), errors, true, false); + } + +} diff --git a/src/sqlancer/presto/gen/PrestoTypedExpressionGenerator.java b/src/sqlancer/presto/gen/PrestoTypedExpressionGenerator.java new file mode 100644 index 000000000..4a788f01b --- /dev/null +++ b/src/sqlancer/presto/gen/PrestoTypedExpressionGenerator.java @@ -0,0 +1,887 @@ +package sqlancer.presto.gen; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.gen.NoRECGenerator; +import sqlancer.common.gen.TLPWhereGenerator; +import sqlancer.common.gen.TypedExpressionGenerator; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.common.schema.AbstractTables; +import sqlancer.presto.PrestoGlobalState; +import sqlancer.presto.PrestoSchema; +import sqlancer.presto.PrestoSchema.PrestoColumn; +import sqlancer.presto.PrestoSchema.PrestoCompositeDataType; +import sqlancer.presto.PrestoSchema.PrestoDataType; +import sqlancer.presto.PrestoSchema.PrestoTable; +import sqlancer.presto.PrestoToStringVisitor; +import sqlancer.presto.ast.PrestoAggregateFunction; +import sqlancer.presto.ast.PrestoAtTimeZoneOperator; +import sqlancer.presto.ast.PrestoBetweenOperation; +import sqlancer.presto.ast.PrestoBinaryOperation; +import sqlancer.presto.ast.PrestoCaseOperation; +import sqlancer.presto.ast.PrestoCastFunction; +import sqlancer.presto.ast.PrestoColumnReference; +import sqlancer.presto.ast.PrestoConstant; +import sqlancer.presto.ast.PrestoDefaultFunction; +import sqlancer.presto.ast.PrestoExpression; +import sqlancer.presto.ast.PrestoFunctionNode; +import sqlancer.presto.ast.PrestoInOperation; +import sqlancer.presto.ast.PrestoJoin; +import sqlancer.presto.ast.PrestoMultiValuedComparison; +import sqlancer.presto.ast.PrestoMultiValuedComparisonOperator; +import sqlancer.presto.ast.PrestoMultiValuedComparisonType; +import sqlancer.presto.ast.PrestoPostfixText; +import sqlancer.presto.ast.PrestoQuantifiedComparison; +import sqlancer.presto.ast.PrestoSelect; +import sqlancer.presto.ast.PrestoTableReference; +import sqlancer.presto.ast.PrestoTernary; +import sqlancer.presto.ast.PrestoUnaryPostfixOperation; +import sqlancer.presto.ast.PrestoUnaryPrefixOperation; + +public final class PrestoTypedExpressionGenerator extends + TypedExpressionGenerator + implements NoRECGenerator, + TLPWhereGenerator { + + private final Randomly randomly; + private final PrestoGlobalState globalState; + private final int maxDepth; + private List tables; + + public PrestoTypedExpressionGenerator(PrestoGlobalState globalState) { + this.globalState = globalState; + this.randomly = globalState.getRandomly(); + this.maxDepth = globalState.getOptions().getMaxExpressionDepth(); + } + + @Override + public PrestoExpression generatePredicate() { + return generateExpression( + PrestoSchema.PrestoCompositeDataType.fromDataType(PrestoSchema.PrestoDataType.BOOLEAN), + randomly.getInteger(0, maxDepth)); + } + + @Override + public PrestoExpression negatePredicate(PrestoExpression predicate) { + return new PrestoUnaryPrefixOperation(predicate, PrestoUnaryPrefixOperation.PrestoUnaryPrefixOperator.NOT); + } + + @Override + public PrestoExpression isNull(PrestoExpression expr) { + return new PrestoUnaryPostfixOperation(expr, PrestoUnaryPostfixOperation.PrestoUnaryPostfixOperator.IS_NULL); + } + + @Override + public PrestoExpression generateConstant(PrestoSchema.PrestoCompositeDataType type) { + if (Objects.requireNonNull(type.getPrimitiveDataType()) == PrestoSchema.PrestoDataType.ARRAY) { + return PrestoConstant.createArrayConstant(type); + // case MAP: + // return PrestoConstant.createMapConstant(type); + } + return PrestoConstant.generateConstant(type, false); + } + + public PrestoExpression generateInsertConstant(PrestoSchema.PrestoCompositeDataType type) { + if (Objects.requireNonNull(type.getPrimitiveDataType()) == PrestoSchema.PrestoDataType.ARRAY) { + return PrestoConstant.createArrayConstant(type); + // case MAP: + // return PrestoConstant.createMapConstant(type); + } + return PrestoConstant.generateConstant(type, true); + } + + @Override + public PrestoExpression generateExpression(PrestoSchema.PrestoCompositeDataType type, int depth) { + if (allowAggregates && Randomly.getBoolean()) { + return generateAggregate(type); + } + if (depth >= globalState.getOptions().getMaxExpressionDepth() || Randomly.getBoolean()) { + return generateLeafNode(type); + } else { + // TODO: functions + List applicableFunctions = PrestoDefaultFunction.getFunctionsCompatibleWith(type); + if (Randomly.getBooleanWithRatherLowProbability() && !applicableFunctions.isEmpty()) { + PrestoDefaultFunction function = Randomly.fromList(applicableFunctions); + return generateFunction(type, depth, function); + } + // TODO: try + // if (Randomly.getBooleanWithRatherLowProbability()) { + // return generateTry(type, depth); + // } + + // TODO: cast + // + // if (Randomly.getBooleanWithRatherLowProbability()) { + // Node expressionNode = generateCast(type, depth); + // } + if (Randomly.getBooleanWithRatherLowProbability()) { + return getCase(type, depth); + } + switch (type.getPrimitiveDataType()) { + case BOOLEAN: + return generateBooleanExpression(depth); + case VARCHAR: + case CHAR: + return generateStringExpression(type, depth); + case INT: + case DECIMAL: + case FLOAT: + return generateNumericExpression(depth); + case DATE: + case TIME: + case TIMESTAMP: + case TIME_WITH_TIME_ZONE: + case TIMESTAMP_WITH_TIME_ZONE: + return generateTemporalExpression(type, depth); + case INTERVAL_YEAR_TO_MONTH: + case INTERVAL_DAY_TO_SECOND: + return generateIntervalExpression(type, depth); + case JSON: + return generateJsonExpression(type); + case VARBINARY: + case ARRAY: + // case MAP: + return generateLeafNode(type); // TODO + default: + throw new AssertionError(type); + } + } + } + + private PrestoExpression generateJsonExpression(PrestoSchema.PrestoCompositeDataType type) { + return generateLeafNode(type); + } + + private PrestoExpression generateCast(PrestoSchema.PrestoCompositeDataType type, int depth) { + // check can cast + PrestoExpression expressionNode = generateExpression(getRandomType(), depth + 1); + return new PrestoCastFunction(expressionNode, type); + } + + @SuppressWarnings("unused") + private PrestoExpression generateTry(PrestoSchema.PrestoCompositeDataType type, int depth) { + if (type.getPrimitiveDataType().isNumeric() && Randomly.getBooleanWithRatherLowProbability()) { + PrestoExpression expression = generateExpression(type); + return new PrestoFunctionNode<>(List.of(expression), "try"); + } + + List applicableFunctions = PrestoDefaultFunction.getFunctionsCompatibleWith(type); + if (Randomly.getBooleanWithRatherLowProbability() && !applicableFunctions.isEmpty()) { + PrestoDefaultFunction function = Randomly.fromList(applicableFunctions); + PrestoExpression expression = generateFunction(type, depth, function); + return new PrestoFunctionNode<>(List.of(expression), "try"); + } + return new PrestoFunctionNode<>(List.of(generateCast(type, depth)), "try"); + } + + private PrestoCaseOperation getCase(PrestoSchema.PrestoCompositeDataType type, int depth) { + List conditions = new ArrayList<>(); + List cases = new ArrayList<>(); + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + conditions.add(generateExpression(type, depth + 1)); + cases.add(generateExpression(type, depth + 1)); + } + PrestoExpression elseExpr = null; + if (Randomly.getBoolean()) { + elseExpr = generateExpression(type, depth + 1); + } + PrestoExpression expression = generateExpression(type); + return new PrestoCaseOperation(expression, conditions, cases, elseExpr); + } + + private PrestoExpression generateFunction(PrestoSchema.PrestoCompositeDataType returnType, int depth, + PrestoDefaultFunction function) { + + PrestoSchema.PrestoDataType[] argumentTypes = function.getArgumentTypes(returnType); + List arguments = new ArrayList<>(); + + // This is a workaround based on the assumption that array types should refer to + // the same element type. + PrestoSchema.PrestoCompositeDataType savedArrayType = null; + if (returnType.getPrimitiveDataType() == PrestoSchema.PrestoDataType.ARRAY) { + savedArrayType = returnType; + } + if (function.getNumberOfArguments() == -1) { + PrestoSchema.PrestoDataType dataType = argumentTypes[0]; + // TODO: consider upper + long no = Randomly.getNotCachedInteger(2, 10); + for (int i = 0; i < no; i++) { + PrestoSchema.PrestoCompositeDataType type; + + if (dataType == PrestoSchema.PrestoDataType.ARRAY) { + if (savedArrayType == null) { + savedArrayType = dataType.get(); + } + type = savedArrayType; + } else { + type = PrestoSchema.PrestoCompositeDataType.fromDataType(dataType); + } + arguments.add(generateExpression(type, depth + 1)); + } + } else { + for (PrestoSchema.PrestoDataType arg : argumentTypes) { + PrestoSchema.PrestoCompositeDataType dataType; + if (arg == PrestoSchema.PrestoDataType.ARRAY) { + if (savedArrayType == null) { + savedArrayType = arg.get(); + } + dataType = savedArrayType; + } else { + dataType = PrestoSchema.PrestoCompositeDataType.fromDataType(arg); + } + PrestoExpression expression = generateExpression(dataType, depth + 1); + arguments.add(expression); + } + } + return new PrestoFunctionNode<>(arguments, function); + } + + private PrestoExpression generateStringExpression(PrestoSchema.PrestoCompositeDataType type, int depth) { + if (depth >= globalState.getOptions().getMaxExpressionDepth() || Randomly.getBoolean()) { + return generateLeafNode(type); + } + return getStringOperation(depth); + } + + private PrestoBinaryOperation getStringOperation(int depth) { + StringExpression exprType = Randomly.fromOptions(StringExpression.values()); + if (Objects.requireNonNull(exprType) == StringExpression.CONCAT) { + PrestoExpression left = generateExpression( + PrestoSchema.PrestoCompositeDataType.fromDataType(PrestoSchema.PrestoDataType.VARCHAR), depth + 1); + PrestoExpression right = generateExpression( + PrestoSchema.PrestoCompositeDataType.fromDataType(PrestoSchema.PrestoDataType.VARCHAR), depth + 1); + PrestBinaryStringOperator operator = PrestBinaryStringOperator.CONCAT; + return new PrestoBinaryOperation(left, right, operator); + } + throw new AssertionError(exprType); + } + + private PrestoExpression generateBooleanExpression(int depth) { + List booleanExpressions = Arrays.stream(BooleanExpression.values()) + .collect(Collectors.toList()); + if (!globalState.getDbmsSpecificOptions().testBetween) { + booleanExpressions.remove(BooleanExpression.BETWEEN); + } + + booleanExpressions.remove(BooleanExpression.REGEX); + + BooleanExpression exprType = Randomly.fromList(booleanExpressions); + switch (exprType) { + case NOT: + return generateNOT(depth + 1); + case BINARY_COMPARISON: + return getBinaryComparison(depth); + case BINARY_LOGICAL: + return getBinaryLogical(depth); + case AND_OR_CHAIN: + return getAndOrChain(depth); + case REGEX: + return getRegex(depth); + case IS_NULL: + return new PrestoUnaryPostfixOperation(generateExpression(getRandomType(), depth + 1), + Randomly.fromOptions(PrestoUnaryPostfixOperation.PrestoUnaryPostfixOperator.IS_NULL, + PrestoUnaryPostfixOperation.PrestoUnaryPostfixOperator.IS_NOT_NULL)); + case IN: + return getInOperation(depth); + case BETWEEN: + return getBetween(depth); + case LIKE: + return getLike(depth); + case MULTI_VALUED_COMPARISON: // TODO other operators + return getMultiValuedComparison(depth); + default: + throw new AssertionError(exprType); + } + } + + private PrestoExpression getMultiValuedComparison(int depth) { + + PrestoSchema.PrestoCompositeDataType type; + do { + type = PrestoSchema.PrestoCompositeDataType + .fromDataType(Randomly.fromList(PrestoSchema.PrestoDataType.getOrderableTypes())); + } while (type.getPrimitiveDataType() == PrestoSchema.PrestoDataType.ARRAY + && !type.getElementType().getPrimitiveDataType().isOrderable()); + + PrestoMultiValuedComparisonType comparisonType = PrestoMultiValuedComparisonType.getRandom(); + PrestoMultiValuedComparisonOperator comparisonOperator = PrestoMultiValuedComparisonOperator + .getRandomForType(type); + PrestoExpression left = generateExpression(type, depth + 1); + // sub-query + PrestoSchema.PrestoCompositeDataType finalType = type; + List columnsOfType = columns.stream().filter(c -> c.getType() == finalType) + .collect(Collectors.toList()); + if (Randomly.getBooleanWithRatherLowProbability() && !columnsOfType.isEmpty()) { + PrestoSchema.PrestoColumn column = Randomly.fromList(columnsOfType); + PrestoSelect subquery = generateSubquery(List.of(column)); + return new PrestoQuantifiedComparison(left, subquery, comparisonType, comparisonOperator); + } + int nr = Randomly.smallNumber() + 2; + List rightList = new ArrayList<>(); + for (int i = 0; i < nr; i++) { + rightList.add(generateConstant(type)); + } + return new PrestoMultiValuedComparison(left, rightList, comparisonType, comparisonOperator); + } + + private PrestoSelect generateSubquery(List columns) { + PrestoSelect select = new PrestoSelect(); + List allColumns = columns.stream().map((c) -> new PrestoColumnReference(c)) + .collect(Collectors.toList()); + select.setFetchColumns(allColumns); + List tables = columns.stream().map(AbstractTableColumn::getTable) + .collect(Collectors.toList()); + List tableList = tables.stream().map(t -> new PrestoTableReference(t)).distinct() + .collect(Collectors.toList()); + List tableNodeList = tables.stream().map(t -> new PrestoTableReference(t)) + .collect(Collectors.toList()); + select.setFromList(tableNodeList); + TypedExpressionGenerator typedExpressionGenerator = new PrestoTypedExpressionGenerator( + globalState).setColumns(columns); + PrestoExpression predicate = typedExpressionGenerator.generatePredicate(); + select.setWhereClause(predicate); + if (Randomly.getBooleanWithSmallProbability()) { + select.setOrderByClauses(typedExpressionGenerator.generateOrderBys()); + } + List joins = PrestoJoin.getJoins(tableList, globalState).stream() + .collect(Collectors.toList()); + select.setJoinList(joins); + return select; + } + + private PrestoExpression generateNumericExpression(int depth) { + PrestoSchema.PrestoDataType dataType = Randomly.fromList(PrestoSchema.PrestoDataType.getNumberTypes()); + PrestoSchema.PrestoCompositeDataType type = PrestoSchema.PrestoCompositeDataType.fromDataType(dataType); + if (Randomly.getBoolean()) { + BinaryOperatorNode.Operator operator = PrestoBinaryArithmeticOperator.getRandom(); + PrestoExpression left = generateExpression(type, depth); + PrestoExpression right = generateExpression(type, depth); + return new PrestoBinaryOperation(left, right, operator); + } else { + BinaryOperatorNode.Operator operator = PrestoUnaryArithmeticOperator.MINUS; + PrestoExpression left = generateExpression(type, depth); + return new PrestoUnaryPrefixOperation(left, operator); + } + } + + private PrestoExpression generateTemporalExpression(PrestoSchema.PrestoCompositeDataType type, int depth) { + if (Randomly.getBooleanWithSmallProbability()) { + PrestoExpression left = generateExpression(type, depth); + PrestoExpression right = generateExpression(PrestoSchema.PrestoCompositeDataType + .fromDataType(Randomly.fromList(PrestoSchema.PrestoDataType.getIntervalTypes())), depth); + BinaryOperatorNode.Operator operator = PrestoBinaryTemporalOperator.getRandom(); + return new PrestoBinaryOperation(left, right, operator); + } + + // timestamp at time zone + if (Randomly.getBooleanWithSmallProbability() + && (type.getPrimitiveDataType() == PrestoSchema.PrestoDataType.TIMESTAMP + || type.getPrimitiveDataType() == PrestoSchema.PrestoDataType.TIMESTAMP_WITH_TIME_ZONE)) { + return new PrestoAtTimeZoneOperator(generateExpression(type, depth + 1), + PrestoConstant.createTimezoneConstant()); + } + return generateLeafNode(type); + } + + private PrestoExpression generateIntervalExpression(PrestoSchema.PrestoCompositeDataType type, int depth) { + if (Randomly.getBooleanWithSmallProbability()) { + PrestoExpression left = generateExpression(type, depth); + + PrestoExpression right; + if (Randomly.getBoolean()) { + right = generateExpression(PrestoSchema.PrestoCompositeDataType + .fromDataType(Randomly.fromList(PrestoSchema.PrestoDataType.getTemporalTypes())), depth); + } else { + right = generateExpression(type, depth); + } + BinaryOperatorNode.Operator operator = PrestoBinaryTemporalOperator.getRandom(); + if (Randomly.getBoolean()) { + return new PrestoBinaryOperation(left, right, operator); + } else { + return new PrestoBinaryOperation(right, left, operator); + } + } + return generateLeafNode(type); + + // functions + + // timestamp at time zone + } + + private PrestoExpression getLike(int depth) { + PrestoSchema.PrestoCompositeDataType type = PrestoSchema.PrestoCompositeDataType + .fromDataType(PrestoSchema.PrestoDataType.VARCHAR); + PrestoExpression expression = generateExpression(type, depth + 1); + PrestoExpression pattern = generateExpression(type, depth + 1); + if (Randomly.getBoolean()) { + return new PrestoBinaryOperation(expression, pattern, PrestoLikeOperator.getRandom()); + } else { + String randomlyString = randomly.getString(); + String randomlyChar = randomly.getChar(); + PrestoExpression escape = new PrestoConstant.PrestoTextConstant(randomlyChar, 1); + int index = randomlyString.indexOf(randomlyChar); + while (index > -1) { + String wildcard = Randomly.fromOptions("%", "_"); + randomlyString = randomlyString.substring(0, index + 1) + wildcard + + randomlyString.substring(index + 1); + index = randomlyString.indexOf(randomlyChar, index + 1); + } + PrestoConstant.PrestoTextConstant patternString = new PrestoConstant.PrestoTextConstant(randomlyString); + return new PrestoTernary(expression, patternString, escape, "LIKE", "ESCAPE"); + } + } + + private PrestoBinaryOperation getRegex(int depth) { + PrestoExpression left = generateExpression( + PrestoSchema.PrestoCompositeDataType.fromDataType(PrestoSchema.PrestoDataType.VARCHAR), depth + 1); + PrestoExpression right = generateExpression( + PrestoSchema.PrestoCompositeDataType.fromDataType(PrestoSchema.PrestoDataType.VARCHAR), depth + 1); + return new PrestoBinaryOperation(left, right, PrestoBinaryLogicalOperator.getRandom()); + } + + private PrestoBinaryOperation getBinaryLogical(int depth) { + PrestoSchema.PrestoCompositeDataType type = PrestoSchema.PrestoCompositeDataType + .fromDataType(PrestoSchema.PrestoDataType.BOOLEAN); + PrestoExpression left = generateExpression(type, depth + 1); + PrestoExpression right = generateExpression(type, depth + 1); + BinaryOperatorNode.Operator operator = PrestoBinaryLogicalOperator.getRandom(); + return new PrestoBinaryOperation(left, right, operator); + } + + private PrestoExpression getBetween(int depth) { + PrestoSchema.PrestoCompositeDataType type = PrestoSchema.PrestoCompositeDataType + .fromDataType(Randomly.fromList(PrestoSchema.PrestoDataType.getNumericTypes())); + PrestoExpression expression = generateExpression(type, depth + 1); + PrestoExpression left = generateExpression(type, depth + 1); + PrestoExpression right = generateExpression(type, depth + 1); + return new PrestoBetweenOperation(expression, left, right, Randomly.getBoolean()); + } + + private PrestoExpression getInOperation(int depth) { + PrestoSchema.PrestoCompositeDataType type = PrestoSchema.PrestoCompositeDataType + .fromDataType(PrestoSchema.PrestoDataType.getRandomWithoutNull()); + PrestoExpression left = generateExpression(type, depth + 1); + List inList = generateExpressions(type, Randomly.smallNumber() + 1, depth + 1); + boolean isNegated = Randomly.getBoolean(); + return new PrestoInOperation(left, inList, isNegated); + } + + private PrestoExpression getAndOrChain(int depth) { + PrestoExpression left = generateExpression( + PrestoSchema.PrestoCompositeDataType.fromDataType(PrestoSchema.PrestoDataType.BOOLEAN), depth + 1); + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + PrestoExpression right = generateExpression( + PrestoSchema.PrestoCompositeDataType.fromDataType(PrestoSchema.PrestoDataType.BOOLEAN), depth + 1); + BinaryOperatorNode.Operator operator = PrestoBinaryLogicalOperator.getRandom(); + left = new PrestoBinaryOperation(left, right, operator); + } + return left; + } + + private PrestoExpression getBinaryComparison(int depth) { + PrestoSchema.PrestoCompositeDataType type = getRandomType(); + BinaryOperatorNode.Operator op = PrestoBinaryComparisonOperator.getRandomForType(type); + PrestoExpression left = generateExpression(type, depth + 1); + PrestoExpression right = generateExpression(type, depth + 1); + return new PrestoBinaryOperation(left, right, op); + } + + private PrestoExpression generateNOT(int depth) { + PrestoUnaryPrefixOperation.PrestoUnaryPrefixOperator operator = PrestoUnaryPrefixOperation.PrestoUnaryPrefixOperator.NOT; + return new PrestoUnaryPrefixOperation( + generateExpression( + PrestoSchema.PrestoCompositeDataType.fromDataType(PrestoSchema.PrestoDataType.BOOLEAN), depth), + operator); + } + + @Override + protected PrestoExpression generateColumn(PrestoSchema.PrestoCompositeDataType type) { + List columnList = columns.stream() + .filter(c -> c.getType().getPrimitiveDataType() == type.getPrimitiveDataType()) + .collect(Collectors.toList()); + PrestoSchema.PrestoColumn column = Randomly.fromList(columnList); + return new PrestoColumnReference(column); + } + + @Override + public PrestoExpression generateLeafNode(PrestoSchema.PrestoCompositeDataType type) { + if (Randomly.getBoolean()) { + return generateConstant(type); + } else { + List columnList = filterColumns(type.getPrimitiveDataType()); + if (columnList.isEmpty()) { + return generateConstant(type); + } else { + return generateColumn(type); + } + } + } + + private List filterColumns(PrestoSchema.PrestoDataType dataType) { + if (columns == null) { + return Collections.emptyList(); + } else { + return columns.stream().filter(c -> c.getType().getPrimitiveDataType() == dataType) + .collect(Collectors.toList()); + } + } + + @Override + protected PrestoSchema.PrestoCompositeDataType getRandomType() { + return PrestoSchema.PrestoCompositeDataType.getRandomWithoutNull(); + } + + @Override + protected boolean canGenerateColumnOfType(PrestoSchema.PrestoCompositeDataType type) { + return columns.stream().anyMatch(c -> c.getType() == type); + } + + public PrestoExpression generateAggregate() { + PrestoAggregateFunction aggregateFunction = PrestoAggregateFunction.getRandom(); + List argsForAggregate = generateArgsForAggregate(aggregateFunction); + return new PrestoFunctionNode<>(argsForAggregate, aggregateFunction); + } + + public List generateArgsForAggregate(PrestoAggregateFunction aggregateFunction) { + PrestoSchema.PrestoCompositeDataType returnType; + do { + returnType = aggregateFunction.getCompositeReturnType(); + } while (!aggregateFunction.isCompatibleWithReturnType(returnType)); + return aggregateFunction.getArgumentsForReturnType(this, this.maxDepth - 1, returnType, false); + } + + private PrestoExpression generateAggregate(PrestoSchema.PrestoCompositeDataType type) { + PrestoAggregateFunction aggregateFunction = Randomly + .fromList(PrestoAggregateFunction.getFunctionsCompatibleWith(type)); + List argsForAggregate = generateArgsForAggregate(type, aggregateFunction); + return new PrestoFunctionNode<>(argsForAggregate, aggregateFunction); + } + + public List generateArgsForAggregate(PrestoSchema.PrestoCompositeDataType type, + PrestoAggregateFunction aggregateFunction) { + List returnTypes = aggregateFunction.getReturnTypes(type.getPrimitiveDataType()); + List arguments = new ArrayList<>(); + allowAggregates = false; // + for (PrestoSchema.PrestoDataType argumentType : returnTypes) { + arguments.add(generateExpression(PrestoSchema.PrestoCompositeDataType.fromDataType(argumentType))); + } + // return new NewFunctionNode<>(arguments, aggregateFunction); + return arguments; + } + + @Override + public List generateOrderBys() { + List expressions = new ArrayList<>(); + int nr = Randomly.smallNumber() + 1; + ArrayList prestoColumns = new ArrayList<>(columns); + prestoColumns.removeIf(c -> !c.isOrderable()); + for (int i = 0; i < nr && !prestoColumns.isEmpty(); i++) { + PrestoSchema.PrestoColumn randomColumn = Randomly.fromList(prestoColumns); + PrestoColumnReference columnReference = new PrestoColumnReference(randomColumn); + prestoColumns.remove(randomColumn); + expressions.add(columnReference); + } + return expressions; + } + + public PrestoExpression generateHavingClause() { + allowAggregates = true; + PrestoExpression expr = generateExpression(PrestoSchema.PrestoCompositeDataType.getRandomWithoutNull()); + allowAggregates = false; + return expr; + } + + public PrestoExpression generateExpressionWithColumns(List columns, int remainingDepth) { + if (columns.isEmpty() || remainingDepth <= 2 && Randomly.getBooleanWithRatherLowProbability()) { + return generateConstant(PrestoSchema.PrestoCompositeDataType.getRandomWithoutNull()); + } + PrestoSchema.PrestoColumn column = Randomly.fromList(columns); + if (remainingDepth <= 2 || Randomly.getBooleanWithRatherLowProbability()) { + return new PrestoColumnReference(column); + } + List possibleOptions = new ArrayList<>( + Arrays.asList(PrestoTypedExpressionGenerator.Expression.values())); + PrestoTypedExpressionGenerator.Expression expr = Randomly.fromList(possibleOptions); + BinaryOperatorNode.Operator op; + switch (expr) { + case BINARY_LOGICAL: + case BINARY_ARITHMETIC: + op = PrestoTypedExpressionGenerator.PrestoBinaryLogicalOperator.getRandom(); + break; + case BINARY_COMPARISON: + op = PrestoBinaryComparisonOperator.getRandom(); + break; + default: + throw new AssertionError(); + } + return new PrestoBinaryOperation(generateExpression(column.getType(), remainingDepth - 1), + generateExpression(column.getType(), remainingDepth - 1), op); + } + + private enum StringExpression { + CONCAT + } + + public enum PrestBinaryStringOperator implements BinaryOperatorNode.Operator { + CONCAT("||"); + + private final String textRepresentation; + + PrestBinaryStringOperator(String textRepresentation) { + this.textRepresentation = textRepresentation; + } + + public static BinaryOperatorNode.Operator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepresentation; + } + + } + + public enum PrestoBinaryTemporalOperator implements BinaryOperatorNode.Operator { + ADD("+"), SUB("-"); + + private final String textRepresentation; + + PrestoBinaryTemporalOperator(String textRepresentation) { + this.textRepresentation = textRepresentation; + } + + public static BinaryOperatorNode.Operator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepresentation; + } + + } + + private enum BooleanExpression { + NOT, BINARY_COMPARISON, BINARY_LOGICAL, AND_OR_CHAIN, REGEX, IS_NULL, IN, BETWEEN, LIKE, MULTI_VALUED_COMPARISON + } + + public enum PrestoBinaryLogicalOperator implements BinaryOperatorNode.Operator { + + AND, OR; + + public static BinaryOperatorNode.Operator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return toString(); + } + + } + + public enum PrestoLikeOperator implements BinaryOperatorNode.Operator { + LIKE("LIKE"), // + NOT_LIKE("NOT LIKE"); + + private final String textRepresentation; + + PrestoLikeOperator(String textRepresentation) { + this.textRepresentation = textRepresentation; + } + + public static PrestoLikeOperator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepresentation; + } + + } + + public enum PrestoBinaryComparisonOperator implements BinaryOperatorNode.Operator { + EQUALS("="), NOT_EQUALS("<>"), NOT_EQUALS_ALT("!="), IS_DISTINCT_FROM("IS DISTINCT FROM"), + IS_NOT_DISTINCT_FROM("IS NOT DISTINCT FROM"), GREATER(">"), GREATER_EQUALS(">="), SMALLER("<"), + SMALLER_EQUALS("<="); + + private final String textRepresentation; + + PrestoBinaryComparisonOperator(String textRepresentation) { + this.textRepresentation = textRepresentation; + } + + public static BinaryOperatorNode.Operator getRandom() { + return Randomly.fromOptions(values()); + } + + public static BinaryOperatorNode.Operator getRandomStringOperator() { + return Randomly.fromOptions(EQUALS, NOT_EQUALS, IS_DISTINCT_FROM, IS_NOT_DISTINCT_FROM); + } + + public static BinaryOperatorNode.Operator getRandomForType(PrestoSchema.PrestoCompositeDataType type) { + PrestoSchema.PrestoDataType dataType = type.getPrimitiveDataType(); + + switch (dataType) { + case BOOLEAN: + case INT: + case FLOAT: + case DECIMAL: + case DATE: + case TIME: + case TIMESTAMP: + case TIME_WITH_TIME_ZONE: + case TIMESTAMP_WITH_TIME_ZONE: + return getRandom(); + case VARCHAR: + case CHAR: + case VARBINARY: + case JSON: + case ARRAY: + case INTERVAL_YEAR_TO_MONTH: + case INTERVAL_DAY_TO_SECOND: + // return Randomly.fromOptions(EQUALS, NOT_EQUALS, NOT_EQUALS_ALT, + // IS_DISTINCT_FROM, + // IS_NOT_DISTINCT_FROM); + default: + return Randomly.fromOptions(EQUALS, NOT_EQUALS, NOT_EQUALS_ALT, IS_DISTINCT_FROM, IS_NOT_DISTINCT_FROM); + } + } + + @Override + public String getTextRepresentation() { + return textRepresentation; + } + + } + + public enum PrestoBinaryArithmeticOperator implements BinaryOperatorNode.Operator { + ADD("+"), SUB("-"), MULT("*"), DIV("/"), MOD("%"); + + private final String textRepresentation; + + PrestoBinaryArithmeticOperator(String textRepresentation) { + this.textRepresentation = textRepresentation; + } + + public static BinaryOperatorNode.Operator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepresentation; + } + + } + + public enum PrestoUnaryArithmeticOperator implements BinaryOperatorNode.Operator { + MINUS("-"); + + private final String textRepresentation; + + PrestoUnaryArithmeticOperator(String textRepresentation) { + this.textRepresentation = textRepresentation; + } + + @Override + public String getTextRepresentation() { + return textRepresentation; + } + + } + + private enum Expression { + BINARY_LOGICAL, BINARY_COMPARISON, BINARY_ARITHMETIC + } + + @Override + public PrestoTypedExpressionGenerator setTablesAndColumns(AbstractTables tables) { + this.columns = tables.getColumns(); + this.tables = tables.getTables(); + + return this; + } + + @Override + public PrestoExpression generateBooleanExpression() { + return generateExpression( + PrestoSchema.PrestoCompositeDataType.fromDataType(PrestoSchema.PrestoDataType.BOOLEAN), + randomly.getInteger(0, maxDepth)); + } + + @Override + public PrestoSelect generateSelect() { + return new PrestoSelect(); + } + + @Override + public List getRandomJoinClauses() { + List tableList = tables.stream().map(t -> new PrestoTableReference(t)) + .collect(Collectors.toList()); + List joins = PrestoJoin.getJoins(tableList, globalState); + tables = tableList.stream().map(t -> t.getTable()).collect(Collectors.toList()); + return joins; + } + + @Override + public List getTableRefs() { + return tables.stream().map(t -> new PrestoTableReference(t)).collect(Collectors.toList()); + } + + @Override + public String generateOptimizedQueryString(PrestoSelect select, PrestoExpression whereCondition, + boolean shouldUseAggregate) { + if (shouldUseAggregate) { + PrestoFunctionNode aggr = new PrestoFunctionNode<>( + List.of(new PrestoColumnReference(new PrestoColumn("*", + new PrestoCompositeDataType(PrestoDataType.INT, 0, 0), false, false))), + PrestoAggregateFunction.COUNT); + select.setFetchColumns(List.of(aggr)); + + } else { + List allColumns = columns.stream().map((c) -> new PrestoColumnReference(c)) + .collect(Collectors.toList()); + select.setFetchColumns(allColumns); + if (Randomly.getBooleanWithSmallProbability()) { + select.setOrderByClauses(generateOrderBys()); + } + } + select.setWhereClause(whereCondition); + + return select.asString(); + } + + @Override + public String generateUnoptimizedQueryString(PrestoSelect select, PrestoExpression whereCondition) { + PrestoExpression asText = new PrestoPostfixText( + + new PrestoCastFunction( + new PrestoPostfixText(whereCondition, + " IS NOT NULL AND " + PrestoToStringVisitor.asString(whereCondition)), + new PrestoCompositeDataType(PrestoDataType.INT, 8, 0)), + "as count"); + + select.setFetchColumns(List.of(asText)); + select.setWhereClause(null); + return "SELECT SUM(count) FROM (" + PrestoToStringVisitor.asString(select) + ") as res"; + } + + @Override + public List generateFetchColumns(boolean shouldCreateDummy) { + if (Randomly.getBoolean()) { + return List.of(new PrestoColumnReference(new PrestoColumn("*", null, false, false))); + } + return Randomly.nonEmptySubset(columns).stream().map(c -> new PrestoColumnReference(c)) + .collect(Collectors.toList()); + } +} diff --git a/src/sqlancer/presto/gen/PrestoUpdateGenerator.java b/src/sqlancer/presto/gen/PrestoUpdateGenerator.java new file mode 100644 index 000000000..a8afcc578 --- /dev/null +++ b/src/sqlancer/presto/gen/PrestoUpdateGenerator.java @@ -0,0 +1,52 @@ +package sqlancer.presto.gen; + +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.gen.AbstractUpdateGenerator; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.presto.PrestoErrors; +import sqlancer.presto.PrestoGlobalState; +import sqlancer.presto.PrestoSchema.PrestoColumn; +import sqlancer.presto.PrestoSchema.PrestoTable; +import sqlancer.presto.PrestoToStringVisitor; +import sqlancer.presto.ast.PrestoExpression; + +public final class PrestoUpdateGenerator extends AbstractUpdateGenerator { + + private final PrestoGlobalState globalState; + private PrestoTypedExpressionGenerator gen; + + private PrestoUpdateGenerator(PrestoGlobalState globalState) { + this.globalState = globalState; + } + + public static SQLQueryAdapter getQuery(PrestoGlobalState globalState) { + return new PrestoUpdateGenerator(globalState).generate(); + } + + private SQLQueryAdapter generate() { + PrestoTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + List columns = table.getRandomNonEmptyColumnSubset(); + gen = new PrestoTypedExpressionGenerator(globalState).setColumns(table.getColumns()); + sb.append("UPDATE "); + sb.append(table.getName()); + sb.append(" SET "); + updateColumns(columns); + PrestoErrors.addInsertErrors(errors); + return new SQLQueryAdapter(sb.toString(), errors, false, false); + } + + @Override + protected void updateValue(PrestoColumn column) { + PrestoExpression expr; + if (Randomly.getBooleanWithSmallProbability()) { + expr = gen.generateExpression(column.getType()); + PrestoErrors.addExpressionErrors(errors); + } else { + expr = gen.generateConstant(column.getType()); + } + sb.append(PrestoToStringVisitor.asString(expr)); + } + +} diff --git a/src/sqlancer/presto/gen/PrestoViewGenerator.java b/src/sqlancer/presto/gen/PrestoViewGenerator.java new file mode 100644 index 000000000..72130cf82 --- /dev/null +++ b/src/sqlancer/presto/gen/PrestoViewGenerator.java @@ -0,0 +1,36 @@ +package sqlancer.presto.gen; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.presto.PrestoErrors; +import sqlancer.presto.PrestoGlobalState; +import sqlancer.presto.PrestoToStringVisitor; + +public final class PrestoViewGenerator { + + private PrestoViewGenerator() { + } + + public static SQLQueryAdapter generate(PrestoGlobalState globalState) { + int nrColumns = Randomly.smallNumber() + 1; + StringBuilder sb = new StringBuilder("CREATE "); + sb.append("VIEW "); + sb.append(globalState.getSchema().getFreeViewName()); + sb.append("("); + for (int i = 0; i < nrColumns; i++) { + if (i != 0) { + sb.append(", "); + } + sb.append("c"); + sb.append(i); + } + sb.append(") AS "); + sb.append(PrestoToStringVisitor.asString(PrestoRandomQuerySynthesizer.generateSelect(globalState, nrColumns))); + ExpectedErrors errors = new ExpectedErrors(); + PrestoErrors.addExpressionErrors(errors); + PrestoErrors.addGroupByErrors(errors); + return new SQLQueryAdapter(sb.toString(), errors, true, false); + } + +} diff --git a/src/sqlancer/presto/test/PrestoQueryPartitioningAggregateTester.java b/src/sqlancer/presto/test/PrestoQueryPartitioningAggregateTester.java new file mode 100644 index 000000000..7ab23286a --- /dev/null +++ b/src/sqlancer/presto/test/PrestoQueryPartitioningAggregateTester.java @@ -0,0 +1,203 @@ +package sqlancer.presto.test; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import sqlancer.ComparatorHelper; +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.query.SQLancerResultSet; +import sqlancer.presto.PrestoErrors; +import sqlancer.presto.PrestoGlobalState; +import sqlancer.presto.PrestoSchema.PrestoCompositeDataType; +import sqlancer.presto.PrestoSchema.PrestoDataType; +import sqlancer.presto.PrestoToStringVisitor; +import sqlancer.presto.ast.PrestoAggregateFunction; +import sqlancer.presto.ast.PrestoAlias; +import sqlancer.presto.ast.PrestoCastFunction; +import sqlancer.presto.ast.PrestoExpression; +import sqlancer.presto.ast.PrestoFunctionNode; +import sqlancer.presto.ast.PrestoSelect; +import sqlancer.presto.ast.PrestoUnaryPostfixOperation; +import sqlancer.presto.ast.PrestoUnaryPrefixOperation; + +public class PrestoQueryPartitioningAggregateTester extends PrestoQueryPartitioningBase + implements TestOracle { + + private String firstResult; + private String firstResultType; + private String secondResult; + private String originalQuery; + private String metamorphicQuery; + + public PrestoQueryPartitioningAggregateTester(PrestoGlobalState state) { + super(state); + PrestoErrors.addGroupByErrors(errors); + PrestoErrors.addExpressionErrors(errors); + } + + @Override + public void check() throws SQLException { + super.check(); + PrestoAggregateFunction aggregateFunction = Randomly.fromOptions(PrestoAggregateFunction.MAX, + PrestoAggregateFunction.MIN, PrestoAggregateFunction.SUM, PrestoAggregateFunction.COUNT, + PrestoAggregateFunction.AVG/* , PrestoAggregateFunction.STDDEV_POP */); + List aggregateArgs = gen.generateArgsForAggregate(aggregateFunction); + PrestoFunctionNode aggregate = new PrestoFunctionNode<>(aggregateArgs, + aggregateFunction); + select.setFetchColumns(List.of(aggregate)); + if (Randomly.getBooleanWithRatherLowProbability()) { + select.setOrderByClauses(gen.generateOrderBys()); + } + originalQuery = PrestoToStringVisitor.asString(select); + firstResult = getAggregateResult(originalQuery); + firstResultType = getAggregateResultType(originalQuery); + metamorphicQuery = createMetamorphicUnionQuery(select, aggregate, select.getFromList()); + secondResult = getAggregateResult(metamorphicQuery); + + state.getState().getLocalState().log( + "--" + originalQuery + ";\n--" + metamorphicQuery + "\n-- " + firstResult + "\n-- " + secondResult); + if (firstResultType.equals("VARBINARY") || firstResultType.equals("ARRAY(VARBINARY)") + || firstResultType.equals("ARRAY(ARRAY(VARBINARY))")) { + throw new IgnoreMeException(); + } + if (firstResult == null && secondResult != null) { + if (secondResult.contains("Inf")) { + throw new IgnoreMeException(); // FIXME: average computation + } + throw new AssertionError(); + } else if (firstResult != null && !firstResult.contentEquals(secondResult) + && !ComparatorHelper.isEqualDouble(firstResult, secondResult)) { + if (secondResult.contains("Inf")) { + throw new IgnoreMeException(); // FIXME: average computation + } + throw new AssertionError(); + } + + } + + private String createMetamorphicUnionQuery(PrestoSelect select, + PrestoFunctionNode aggregate, List from) { + String metamorphicQuery; + PrestoExpression whereClause = gen.generatePredicate(); + PrestoExpression negatedClause = new PrestoUnaryPrefixOperation(whereClause, + PrestoUnaryPrefixOperation.PrestoUnaryPrefixOperator.NOT); + PrestoExpression notNullClause = new PrestoUnaryPostfixOperation(whereClause, + PrestoUnaryPostfixOperation.PrestoUnaryPostfixOperator.IS_NULL); + List mappedAggregate = mapped(aggregate); + PrestoSelect leftSelect = getSelect(mappedAggregate, from, whereClause, select.getJoinList()); + PrestoSelect middleSelect = getSelect(mappedAggregate, from, negatedClause, select.getJoinList()); + PrestoSelect rightSelect = getSelect(mappedAggregate, from, notNullClause, select.getJoinList()); + metamorphicQuery = "SELECT " + getOuterAggregateFunction(aggregate) + " FROM ("; + metamorphicQuery += PrestoToStringVisitor.asString(leftSelect) + " UNION ALL " + + PrestoToStringVisitor.asString(middleSelect) + " UNION ALL " + + PrestoToStringVisitor.asString(rightSelect); + metamorphicQuery += ") as asdf"; + return metamorphicQuery; + } + + private String getAggregateResult(String queryString) { + String resultString; + SQLQueryAdapter q = new SQLQueryAdapter(queryString, errors, false, false); + try (SQLancerResultSet result = q.executeAndGet(state)) { + if (result == null) { + throw new IgnoreMeException(); + } + if (!result.next()) { + resultString = null; + } else { + resultString = result.getString(1); + } + return resultString; + } catch (SQLException e) { + if (errors.errorIsExpected(e.getMessage())) { + throw new IgnoreMeException(); + } + + if (!e.getMessage().contains("Not implemented type")) { + throw new AssertionError(queryString, e); + } else { + throw new IgnoreMeException(); + } + } + } + + private String getAggregateResultType(String queryString) { + String resultString; + SQLQueryAdapter q = new SQLQueryAdapter(queryString, errors, false, false); + try (SQLancerResultSet result = q.executeAndGet(state)) { + if (result == null) { + throw new IgnoreMeException(); + } + if (!result.next()) { + resultString = null; + } else { + resultString = result.getType(1); + } + return resultString; + } catch (SQLException e) { + if (!e.getMessage().contains("Not implemented type")) { + throw new AssertionError(queryString, e); + } else { + throw new IgnoreMeException(); + } + } + } + + private List mapped(PrestoFunctionNode aggregate) { + PrestoCastFunction count; + switch (aggregate.getFunc()) { + case COUNT: + case MAX: + case MIN: + case SUM: + return aliasArgs(List.of(aggregate)); + case AVG: + PrestoFunctionNode sum = new PrestoFunctionNode<>(aggregate.getArgs(), + PrestoAggregateFunction.SUM); + count = new PrestoCastFunction(new PrestoFunctionNode<>(aggregate.getArgs(), PrestoAggregateFunction.COUNT), + new PrestoCompositeDataType(PrestoDataType.FLOAT, 8, 0)); + return aliasArgs(Arrays.asList(sum, count)); + default: + throw new AssertionError(aggregate.getFunc()); + } + } + + private List aliasArgs(List originalAggregateArgs) { + List args = new ArrayList<>(); + int i = 0; + for (PrestoExpression expr : originalAggregateArgs) { + args.add(new PrestoAlias(expr, "agg" + i++)); + } + return args; + } + + private String getOuterAggregateFunction(PrestoFunctionNode aggregate) { + switch (aggregate.getFunc()) { + case AVG: + return "SUM(CAST(agg0 AS DOUBLE))/CAST(SUM(agg1) AS DOUBLE)"; + case COUNT: + return PrestoAggregateFunction.SUM + "(agg0)"; + default: + return aggregate.getFunc().toString() + "(agg0)"; + } + } + + private PrestoSelect getSelect(List aggregates, List from, + PrestoExpression whereClause, List joinList) { + PrestoSelect leftSelect = new PrestoSelect(); + leftSelect.setFetchColumns(aggregates); + leftSelect.setFromList(from); + leftSelect.setWhereClause(whereClause); + leftSelect.setJoinList(joinList); + if (Randomly.getBooleanWithSmallProbability()) { + leftSelect.setGroupByExpressions(gen.generateExpressions(Randomly.smallNumber() + 1)); + } + return leftSelect; + } + +} diff --git a/src/sqlancer/presto/test/PrestoQueryPartitioningBase.java b/src/sqlancer/presto/test/PrestoQueryPartitioningBase.java new file mode 100644 index 000000000..0e80d5c41 --- /dev/null +++ b/src/sqlancer/presto/test/PrestoQueryPartitioningBase.java @@ -0,0 +1,88 @@ +package sqlancer.presto.test; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.common.gen.ExpressionGenerator; +import sqlancer.common.oracle.TernaryLogicPartitioningOracleBase; +import sqlancer.common.oracle.TestOracle; +import sqlancer.presto.PrestoErrors; +import sqlancer.presto.PrestoGlobalState; +import sqlancer.presto.PrestoSchema; +import sqlancer.presto.PrestoSchema.PrestoColumn; +import sqlancer.presto.PrestoSchema.PrestoTable; +import sqlancer.presto.PrestoSchema.PrestoTables; +import sqlancer.presto.ast.PrestoColumnReference; +import sqlancer.presto.ast.PrestoExpression; +import sqlancer.presto.ast.PrestoJoin; +import sqlancer.presto.ast.PrestoSelect; +import sqlancer.presto.ast.PrestoTableReference; +import sqlancer.presto.gen.PrestoTypedExpressionGenerator; + +public class PrestoQueryPartitioningBase extends TernaryLogicPartitioningOracleBase + implements TestOracle { + + PrestoSchema s; + PrestoTables targetTables; + PrestoTypedExpressionGenerator gen; + PrestoSelect select; + + public PrestoQueryPartitioningBase(PrestoGlobalState state) { + super(state); + PrestoErrors.addExpressionErrors(errors); + } + + public static String canonicalizeResultValue(String value) { + if (value == null) { + return null; + } + + // TODO: check this + switch (value) { + case "-0.0": + return "0.0"; + case "-0": + return "0"; + default: + } + + return value; + } + + @Override + public void check() throws SQLException { + s = state.getSchema(); + targetTables = s.getRandomTableNonEmptyTables(); + gen = new PrestoTypedExpressionGenerator(state).setColumns(targetTables.getColumns()); + initializeTernaryPredicateVariants(); + select = new PrestoSelect(); + select.setFetchColumns(generateFetchColumns()); + List tables = targetTables.getTables(); + List tableList = tables.stream().map(t -> new PrestoTableReference(t)) + .collect(Collectors.toList()); + List joins = PrestoJoin.getJoins(tableList, state).stream().collect(Collectors.toList()); + select.setJoinList(new ArrayList<>(joins)); + select.setFromList(new ArrayList<>(tableList)); + select.setWhereClause(null); + } + + List generateFetchColumns() { + List columns = new ArrayList<>(); + if (Randomly.getBoolean()) { + columns.add(new PrestoColumnReference(new PrestoColumn("*", null, false, false))); + } else { + columns = Randomly.nonEmptySubset(targetTables.getColumns()).stream().map(c -> new PrestoColumnReference(c)) + .collect(Collectors.toList()); + } + return columns; + } + + @Override + protected ExpressionGenerator getGen() { + return gen; + } + +} diff --git a/src/sqlancer/presto/test/PrestoQueryPartitioningDistinctTester.java b/src/sqlancer/presto/test/PrestoQueryPartitioningDistinctTester.java new file mode 100644 index 000000000..008e8f4db --- /dev/null +++ b/src/sqlancer/presto/test/PrestoQueryPartitioningDistinctTester.java @@ -0,0 +1,44 @@ +package sqlancer.presto.test; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import sqlancer.ComparatorHelper; +import sqlancer.Randomly; +import sqlancer.presto.PrestoErrors; +import sqlancer.presto.PrestoGlobalState; +import sqlancer.presto.PrestoToStringVisitor; + +public class PrestoQueryPartitioningDistinctTester extends PrestoQueryPartitioningBase { + + public PrestoQueryPartitioningDistinctTester(PrestoGlobalState state) { + super(state); + PrestoErrors.addGroupByErrors(errors); + } + + @Override + public void check() throws SQLException { + super.check(); + select.setDistinct(true); + select.setWhereClause(null); + String originalQueryString = PrestoToStringVisitor.asString(select); + + List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); + if (Randomly.getBoolean()) { + select.setDistinct(false); + } + select.setWhereClause(predicate); + String firstQueryString = PrestoToStringVisitor.asString(select); + select.setWhereClause(negatedPredicate); + String secondQueryString = PrestoToStringVisitor.asString(select); + select.setWhereClause(isNullPredicate); + String thirdQueryString = PrestoToStringVisitor.asString(select); + List combinedString = new ArrayList<>(); + List secondResultSet = ComparatorHelper.getCombinedResultSetNoDuplicates(firstQueryString, + secondQueryString, thirdQueryString, combinedString, true, state, errors); + ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, + state, PrestoQueryPartitioningBase::canonicalizeResultValue); + } + +} diff --git a/src/sqlancer/presto/test/PrestoQueryPartitioningGroupByTester.java b/src/sqlancer/presto/test/PrestoQueryPartitioningGroupByTester.java new file mode 100644 index 000000000..6c3f4bf71 --- /dev/null +++ b/src/sqlancer/presto/test/PrestoQueryPartitioningGroupByTester.java @@ -0,0 +1,51 @@ +package sqlancer.presto.test; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.ComparatorHelper; +import sqlancer.Randomly; +import sqlancer.presto.PrestoErrors; +import sqlancer.presto.PrestoGlobalState; +import sqlancer.presto.PrestoToStringVisitor; +import sqlancer.presto.ast.PrestoColumnReference; +import sqlancer.presto.ast.PrestoExpression; + +public class PrestoQueryPartitioningGroupByTester extends PrestoQueryPartitioningBase { + + public PrestoQueryPartitioningGroupByTester(PrestoGlobalState state) { + super(state); + PrestoErrors.addGroupByErrors(errors); + } + + @Override + public void check() throws SQLException { + super.check(); + select.setGroupByExpressions(select.getFetchColumns()); + select.setWhereClause(null); + String originalQueryString = PrestoToStringVisitor.asString(select); + + List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); + + select.setWhereClause(predicate); + String firstQueryString = PrestoToStringVisitor.asString(select); + select.setWhereClause(negatedPredicate); + String secondQueryString = PrestoToStringVisitor.asString(select); + select.setWhereClause(isNullPredicate); + String thirdQueryString = PrestoToStringVisitor.asString(select); + List combinedString = new ArrayList<>(); + List secondResultSet = ComparatorHelper.getCombinedResultSetNoDuplicates(firstQueryString, + secondQueryString, thirdQueryString, combinedString, true, state, errors); + ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, + state, PrestoQueryPartitioningBase::canonicalizeResultValue); + } + + @Override + List generateFetchColumns() { + return Randomly.nonEmptySubset(targetTables.getColumns()).stream().map(c -> new PrestoColumnReference(c)) + .collect(Collectors.toList()); + } + +} diff --git a/src/sqlancer/presto/test/PrestoQueryPartitioningHavingTester.java b/src/sqlancer/presto/test/PrestoQueryPartitioningHavingTester.java new file mode 100644 index 000000000..b53bfc07f --- /dev/null +++ b/src/sqlancer/presto/test/PrestoQueryPartitioningHavingTester.java @@ -0,0 +1,63 @@ +package sqlancer.presto.test; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import sqlancer.ComparatorHelper; +import sqlancer.Randomly; +import sqlancer.common.oracle.TestOracle; +import sqlancer.presto.PrestoErrors; +import sqlancer.presto.PrestoGlobalState; +import sqlancer.presto.PrestoSchema; +import sqlancer.presto.PrestoToStringVisitor; +import sqlancer.presto.ast.PrestoExpression; + +public class PrestoQueryPartitioningHavingTester extends PrestoQueryPartitioningBase + implements TestOracle { + + public PrestoQueryPartitioningHavingTester(PrestoGlobalState state) { + super(state); + PrestoErrors.addGroupByErrors(errors); + } + + @Override + public void check() throws SQLException { + super.check(); + if (Randomly.getBoolean()) { + select.setWhereClause(gen.generateExpression(PrestoSchema.PrestoCompositeDataType.getRandomWithoutNull())); + } + boolean orderBy = Randomly.getBoolean(); + if (orderBy) { + select.setOrderByClauses(gen.generateOrderBys()); + } + select.setGroupByExpressions(gen.generateExpressions(Randomly.smallNumber() + 1)); + select.setHavingClause(null); + String originalQueryString = PrestoToStringVisitor.asString(select); + List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); + + select.setHavingClause(predicate); + String firstQueryString = PrestoToStringVisitor.asString(select); + select.setHavingClause(negatedPredicate); + String secondQueryString = PrestoToStringVisitor.asString(select); + select.setHavingClause(isNullPredicate); + String thirdQueryString = PrestoToStringVisitor.asString(select); + List combinedString = new ArrayList<>(); + List secondResultSet = ComparatorHelper.getCombinedResultSet(firstQueryString, secondQueryString, + thirdQueryString, combinedString, !orderBy, state, errors); + ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, + state, PrestoQueryPartitioningBase::canonicalizeResultValue); + } + + @Override + protected PrestoExpression generatePredicate() { + return gen.generateHavingClause(); + } + + @Override + List generateFetchColumns() { + return Collections.singletonList(gen.generateHavingClause()); + } + +} diff --git a/src/sqlancer/presto/test/PrestoQueryPartitioningWhereTester.java b/src/sqlancer/presto/test/PrestoQueryPartitioningWhereTester.java new file mode 100644 index 000000000..1fb2b7018 --- /dev/null +++ b/src/sqlancer/presto/test/PrestoQueryPartitioningWhereTester.java @@ -0,0 +1,45 @@ +package sqlancer.presto.test; + +import java.sql.SQLException; + +import sqlancer.Reproducer; +import sqlancer.common.oracle.TLPWhereOracle; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.presto.PrestoErrors; +import sqlancer.presto.PrestoGlobalState; +import sqlancer.presto.PrestoSchema; +import sqlancer.presto.PrestoSchema.PrestoColumn; +import sqlancer.presto.PrestoSchema.PrestoTable; +import sqlancer.presto.ast.PrestoExpression; +import sqlancer.presto.ast.PrestoJoin; +import sqlancer.presto.ast.PrestoSelect; +import sqlancer.presto.gen.PrestoTypedExpressionGenerator; + +public class PrestoQueryPartitioningWhereTester implements TestOracle { + + private final TLPWhereOracle oracle; + + public PrestoQueryPartitioningWhereTester(PrestoGlobalState state) { + PrestoTypedExpressionGenerator gen = new PrestoTypedExpressionGenerator(state); + ExpectedErrors expectedErrors = ExpectedErrors.newErrors().with(PrestoErrors.getExpressionErrors()) + .with(PrestoErrors.getGroupByErrors()).build(); + + this.oracle = new TLPWhereOracle<>(state, gen, expectedErrors); + } + + @Override + public void check() throws SQLException { + oracle.check(); + } + + @Override + public String getLastQueryString() { + return oracle.getLastQueryString(); + } + + @Override + public Reproducer getLastReproducer() { + return oracle.getLastReproducer(); + } +} diff --git a/src/sqlancer/questdb/QuestDBBugs.java b/src/sqlancer/questdb/QuestDBBugs.java new file mode 100644 index 000000000..1bd565823 --- /dev/null +++ b/src/sqlancer/questdb/QuestDBBugs.java @@ -0,0 +1,7 @@ +package sqlancer.questdb; + +public final class QuestDBBugs { + + private QuestDBBugs() { + } +} diff --git a/src/sqlancer/questdb/QuestDBErrors.java b/src/sqlancer/questdb/QuestDBErrors.java new file mode 100644 index 000000000..213f06ad4 --- /dev/null +++ b/src/sqlancer/questdb/QuestDBErrors.java @@ -0,0 +1,56 @@ +package sqlancer.questdb; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.common.query.ExpectedErrors; + +public final class QuestDBErrors { + + private QuestDBErrors() { + } + + public static List getExpressionErrors() { + ArrayList errors = new ArrayList<>(); + + // TODO (anxing) + errors.add("unexpected argument for function: "); + errors.add("unexpected token:"); // SELECT FROM multiple tables without WHERE/ JOIN clause + errors.add("boolean expression expected"); + errors.add("Column name expected"); + errors.add("too few arguments for 'in'"); + errors.add("cannot compare TIMESTAMP with type"); // WHERE column IN with nonTIMESTAMP arg + errors.add("constant expected"); + + return errors; + } + + public static void addExpressionErrors(ExpectedErrors errors) { + errors.addAll(getExpressionErrors()); + } + + public static List getGroupByErrors() { + // TODO (anxing) + + return new ArrayList<>(); + } + + public static void addGroupByErrors(ExpectedErrors errors) { + errors.addAll(getGroupByErrors()); + } + + public static List getInsertErrors() { + ArrayList errors = new ArrayList<>(); + + // TODO (anxing) + errors.add("Invalid column"); + errors.add("inconvertible types:"); + errors.add("inconvertible value:"); + + return errors; + } + + public static void addInsertErrors(ExpectedErrors errors) { + errors.addAll(getInsertErrors()); + } +} diff --git a/src/sqlancer/questdb/QuestDBOptions.java b/src/sqlancer/questdb/QuestDBOptions.java new file mode 100644 index 000000000..82cf55bd9 --- /dev/null +++ b/src/sqlancer/questdb/QuestDBOptions.java @@ -0,0 +1,39 @@ +package sqlancer.questdb; + +import java.util.Arrays; +import java.util.List; + +import com.beust.jcommander.Parameter; +import com.beust.jcommander.Parameters; + +import sqlancer.DBMSSpecificOptions; + +@Parameters(separators = "=", commandDescription = "QuestDB (default port: " + QuestDBOptions.DEFAULT_PORT + + " default host: " + QuestDBOptions.DEFAULT_HOST + ")") +public class QuestDBOptions implements DBMSSpecificOptions { + public static final String DEFAULT_HOST = "localhost"; + public static final int DEFAULT_PORT = 8812; + + @Parameter(names = "--oracle") + public List oracle = Arrays.asList(QuestDBOracleFactory.WHERE); + + @Parameter(names = "--username", description = "The user name used to log into QuestDB") + private String userName = "admin"; // NOPMD + + @Parameter(names = "--password", description = "The password used to log into QuestDB") + private String password = "quest"; // NOPMD + + @Override + public List getTestOracleFactory() { + return oracle; + } + + public String getUserName() { + return userName; + } + + public String getPassword() { + return password; + } + +} diff --git a/src/sqlancer/questdb/QuestDBOracleFactory.java b/src/sqlancer/questdb/QuestDBOracleFactory.java new file mode 100644 index 000000000..52c727278 --- /dev/null +++ b/src/sqlancer/questdb/QuestDBOracleFactory.java @@ -0,0 +1,18 @@ +package sqlancer.questdb; + +import java.sql.SQLException; + +import sqlancer.OracleFactory; +import sqlancer.common.oracle.TestOracle; +import sqlancer.questdb.test.QuestDBQueryPartitioningWhereTester; + +public enum QuestDBOracleFactory implements OracleFactory { + // TODO (anxing): implement test oracles + WHERE { + @Override + public TestOracle create(QuestDBProvider.QuestDBGlobalState globalState) + throws SQLException { + return new QuestDBQueryPartitioningWhereTester(globalState); + } + } +} diff --git a/src/sqlancer/questdb/QuestDBProvider.java b/src/sqlancer/questdb/QuestDBProvider.java new file mode 100644 index 000000000..79c1554be --- /dev/null +++ b/src/sqlancer/questdb/QuestDBProvider.java @@ -0,0 +1,154 @@ +package sqlancer.questdb; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Properties; + +import com.google.auto.service.AutoService; + +import sqlancer.AbstractAction; +import sqlancer.DatabaseProvider; +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.SQLGlobalState; +import sqlancer.SQLProviderAdapter; +import sqlancer.StatementExecutor; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.query.SQLQueryProvider; +import sqlancer.questdb.QuestDBProvider.QuestDBGlobalState; +import sqlancer.questdb.gen.QuestDBAlterIndexGenerator; +import sqlancer.questdb.gen.QuestDBInsertGenerator; +import sqlancer.questdb.gen.QuestDBTableGenerator; +import sqlancer.questdb.gen.QuestDBTruncateGenerator; + +@AutoService(DatabaseProvider.class) +public class QuestDBProvider extends SQLProviderAdapter { + public QuestDBProvider() { + super(QuestDBGlobalState.class, QuestDBOptions.class); + } + + public enum Action implements AbstractAction { + INSERT(QuestDBInsertGenerator::getQuery), // + ALTER_INDEX(QuestDBAlterIndexGenerator::getQuery), // + TRUNCATE(QuestDBTruncateGenerator::generate); // + // TODO (anxing): maybe implement these later + // UPDATE(QuestDBUpdateGenerator::getQuery), // + // CREATE_VIEW(QuestDBViewGenerator::generate), // + + private final SQLQueryProvider sqlQueryProvider; + + Action(SQLQueryProvider sqlQueryProvider) { + this.sqlQueryProvider = sqlQueryProvider; + } + + @Override + public SQLQueryAdapter getQuery(QuestDBGlobalState state) throws Exception { + return sqlQueryProvider.getQuery(state); + } + } + + private static int mapActions(QuestDBGlobalState globalState, Action a) { + Randomly r = globalState.getRandomly(); + switch (a) { + case INSERT: + return r.getInteger(0, globalState.getOptions().getMaxNumberInserts()); + case ALTER_INDEX: + return r.getInteger(0, 3); + case TRUNCATE: + return r.getInteger(0, 5); + default: + throw new AssertionError("Unknown action: " + a); + } + } + + public static class QuestDBGlobalState extends SQLGlobalState { + + @Override + protected QuestDBSchema readSchema() throws SQLException { + return QuestDBSchema.fromConnection(getConnection(), getDatabaseName()); + } + + } + + @Override + public void generateDatabase(QuestDBGlobalState globalState) throws Exception { + for (int i = 0; i < Randomly.fromOptions(1, 2); i++) { + boolean success; + do { + SQLQueryAdapter qt = new QuestDBTableGenerator().getQuery(globalState, null); + success = globalState.executeStatement(qt); + } while (!success); + } + if (globalState.getSchema().getDatabaseTables().isEmpty()) { + throw new IgnoreMeException(); + } + StatementExecutor se = new StatementExecutor<>(globalState, Action.values(), + QuestDBProvider::mapActions, (q) -> { + if (globalState.getSchema().getDatabaseTables().isEmpty()) { + throw new IgnoreMeException(); + } + }); + se.executeStatements(); + } + + @Override + public SQLConnection createDatabase(QuestDBGlobalState globalState) throws Exception { + String host = globalState.getOptions().getHost(); + int port = globalState.getOptions().getPort(); + if (host == null) { + host = QuestDBOptions.DEFAULT_HOST; + } + if (port == sqlancer.MainOptions.NO_SET_PORT) { + port = QuestDBOptions.DEFAULT_PORT; + } + // TODO(anxing): maybe not hardcode here... + String databaseName = "qdb"; + String tableName = "sqlancer_test"; + String url = String.format("jdbc:postgresql://%s:%d/%s", host, port, databaseName); + // use QuestDB default username & password for Postgres JDBC + Properties properties = new Properties(); + properties.setProperty("user", globalState.getDbmsSpecificOptions().getUserName()); + properties.setProperty("password", globalState.getDbmsSpecificOptions().getPassword()); + properties.setProperty("sslmode", "disable"); + + Connection con = DriverManager.getConnection(url, properties); + // QuestDB cannot create or drop `DATABASE`, can only create or drop `TABLE` + globalState.getState().logStatement("DROP TABLE IF EXISTS " + tableName + " CASCADE"); + SQLQueryAdapter createTableCommand = new QuestDBTableGenerator().getQuery(globalState, tableName); + globalState.getState().logStatement(createTableCommand); + globalState.getState().logStatement("DROP TABLE IF EXISTS " + tableName); + + try (Statement s = con.createStatement()) { + s.execute("DROP TABLE IF EXISTS " + tableName); + } + // TODO(anxing): Drop all previous tables in db + // List tableNames = + // globalState.getSchema().getDatabaseTables().stream().map(AbstractTable::getName).collect(Collectors.toList()); + // for (String tName : tableNames) { + // try (Statement s = con.createStatement()) { + // String query = "DROP TABLE IF EXISTS " + tName; + // globalState.getState().logStatement(query); + // s.execute(query); + // } + // } + try (Statement s = con.createStatement()) { + s.execute(createTableCommand.getQueryString()); + } + // drop test table + try (Statement s = con.createStatement()) { + s.execute("DROP TABLE IF EXISTS " + tableName); + } + con.close(); + con = DriverManager.getConnection(url, properties); + return new SQLConnection(con); + } + + @Override + public String getDBMSName() { + return "questdb"; + } + +} diff --git a/src/sqlancer/questdb/QuestDBSchema.java b/src/sqlancer/questdb/QuestDBSchema.java new file mode 100644 index 000000000..55ee01aab --- /dev/null +++ b/src/sqlancer/questdb/QuestDBSchema.java @@ -0,0 +1,314 @@ +package sqlancer.questdb; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.common.DBMSCommon; +import sqlancer.common.schema.AbstractRelationalTable; +import sqlancer.common.schema.AbstractSchema; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.common.schema.AbstractTables; +import sqlancer.common.schema.TableIndex; +import sqlancer.questdb.QuestDBProvider.QuestDBGlobalState; +import sqlancer.questdb.QuestDBSchema.QuestDBTable; + +public class QuestDBSchema extends AbstractSchema { + + public enum QuestDBDataType { + + BOOLEAN, // CHAR, + /* STRING, */ + INT, FLOAT, SYMBOL, + // DATE, TIMESTAMP, + /* GEOHASH, */ + NULL; + + public static QuestDBDataType getRandomWithoutNull() { + QuestDBDataType dt; + do { + dt = Randomly.fromOptions(values()); + } while (dt == QuestDBDataType.NULL); + return dt; + } + + } + + public static class QuestDBCompositeDataType { + + private final QuestDBDataType dataType; + + private final int size; + + private final boolean isNullable; + + public QuestDBCompositeDataType(QuestDBDataType dataType, int size) { + this.dataType = dataType; + this.size = size; + + switch (dataType) { + case INT: + switch (size) { + case 1: + case 2: + isNullable = false; + break; + default: + isNullable = true; + break; + } + break; + case BOOLEAN: + isNullable = false; + break; + case SYMBOL: + isNullable = true; + break; + default: + isNullable = true; + } + } + + public QuestDBDataType getPrimitiveDataType() { + return dataType; + } + + public int getSize() { + if (size == -1) { + throw new AssertionError(this); + } + return size; + } + + public boolean isNullable() { + return isNullable; + } + + public static QuestDBCompositeDataType getRandomWithoutNull() { + QuestDBDataType type = QuestDBDataType.getRandomWithoutNull(); + int size = -1; + switch (type) { + case INT: + size = Randomly.fromOptions(1, 2, 4); + break; + case FLOAT: + size = Randomly.fromOptions(4, 8, 32); + break; + case BOOLEAN: + // case CHAR: + // case DATE: + // case TIMESTAMP: + size = 0; + break; + case SYMBOL: + size = 0; + break; + default: + throw new AssertionError(type); + } + + return new QuestDBCompositeDataType(type, size); + } + + @Override + public String toString() { + switch (getPrimitiveDataType()) { + case INT: + switch (size) { + case 1: + return Randomly.fromOptions("BYTE"); + case 2: + return Randomly.fromOptions("SHORT"); + case 4: + return Randomly.fromOptions("INT"); + default: + throw new AssertionError(size); + } + // case CHAR: + // return "CHAR"; + case FLOAT: + switch (size) { + case 4: + return Randomly.fromOptions("FLOAT"); + case 8: + return Randomly.fromOptions(/* "DOUBLE", */"LONG"); + case 32: + return Randomly.fromOptions("LONG256"); + default: + throw new AssertionError(size); + } + case BOOLEAN: + return Randomly.fromOptions("BOOLEAN"); + case SYMBOL: + return "SYMBOL"; + // case TIMESTAMP: + // return Randomly.fromOptions("TIMESTAMP"); + // case DATE: + // return Randomly.fromOptions("DATE"); + case NULL: + return Randomly.fromOptions("NULL"); + default: + throw new AssertionError(getPrimitiveDataType()); + } + } + + } + + public static class QuestDBColumn extends AbstractTableColumn { + private final boolean isIndexed; + private final boolean isNullable; + + public QuestDBColumn(String name, QuestDBCompositeDataType columnType, boolean isIndexed) { + super(name, null, columnType); + this.isIndexed = isIndexed; + this.isNullable = columnType == null || columnType.isNullable(); + } + + public boolean isIndexed() { + return isIndexed; + } + + public boolean isNullable() { + return isNullable; + } + + } + + public static class QuestDBTables extends AbstractTables { + public static final Set RESERVED_TABLES = new HashSet<>( + Arrays.asList("sys.column_versions_purge_log", "telemetry_config", "telemetry", "sys.telemetry_wal")); + + public QuestDBTables(List tables) { + super(tables); + } + } + + public QuestDBSchema(List databaseTables) { + super(databaseTables); + } + + public QuestDBTables getRandomTableNonEmptyTables() { + return new QuestDBTables(Randomly.nonEmptySubset(getDatabaseTables())); + } + + private static QuestDBCompositeDataType getColumnType(String typeString) { + QuestDBDataType primitiveType; + int size = -1; + + switch (typeString) { + case "INT": + primitiveType = QuestDBDataType.INT; + size = 4; + break; + // case "CHAR": + // primitiveType = QuestDBDataType.CHAR; + // break; + case "FLOAT": + primitiveType = QuestDBDataType.FLOAT; + size = 4; + break; + case "LONG": + primitiveType = QuestDBDataType.FLOAT; + size = 8; + break; + case "LONG256": + primitiveType = QuestDBDataType.FLOAT; + size = 32; + break; + case "BOOLEAN": + primitiveType = QuestDBDataType.BOOLEAN; + break; + // case "DATE": + // primitiveType = QuestDBDataType.DATE; + // break; + // case "TIMESTAMP": + // primitiveType = QuestDBDataType.TIMESTAMP; + // break; + case "BYTE": + primitiveType = QuestDBDataType.INT; + size = 1; + break; + case "SHORT": + primitiveType = QuestDBDataType.INT; + size = 2; + break; + case "SYMBOL": + primitiveType = QuestDBDataType.SYMBOL; + break; + case "NULL": + primitiveType = QuestDBDataType.NULL; + break; + default: + throw new AssertionError(typeString); + } + return new QuestDBCompositeDataType(primitiveType, size); + } + + public static class QuestDBTable extends AbstractRelationalTable { + + public QuestDBTable(String tableName, List columns, boolean isView) { + super(tableName, columns, Collections.emptyList(), isView); + } + + } + + public static QuestDBSchema fromConnection(SQLConnection con, String databaseName) throws SQLException { + List databaseTables = new ArrayList<>(); + List tableNames = getTableNames(con); + for (String tableName : tableNames) { + if (DBMSCommon.matchesIndexName(tableName)) { + continue; // TODO: unexpected? + } + List databaseColumns = getTableColumns(con, tableName); + boolean isView = matchesViewName(tableName); + QuestDBTable t = new QuestDBTable(tableName, databaseColumns, isView); + for (QuestDBColumn c : databaseColumns) { + c.setTable(t); + } + databaseTables.add(t); + + } + return new QuestDBSchema(databaseTables); + } + + protected static List getTableNames(SQLConnection con) throws SQLException { + List tableNames = new ArrayList<>(); + try (Statement s = con.createStatement()) { + try (ResultSet rs = s.executeQuery("SHOW TABLES;")) { + while (rs.next()) { + String tName = rs.getString("table"); + // exclude reserved tables for testing + if (!QuestDBTables.RESERVED_TABLES.contains(tName)) { + tableNames.add(tName); + } + } + } + } + return tableNames; + } + + private static List getTableColumns(SQLConnection con, String tableName) throws SQLException { + List columns = new ArrayList<>(); + try (Statement s = con.createStatement()) { + try (ResultSet rs = s.executeQuery(String.format("SHOW COLUMNS FROM %s;", tableName))) { + while (rs.next()) { + String columnName = rs.getString("column"); + String dataType = rs.getString("type"); + boolean isIndexed = rs.getString("indexed").contains("true"); + QuestDBColumn c = new QuestDBColumn(columnName, getColumnType(dataType), isIndexed); + columns.add(c); + } + } + } + return columns; + } + +} diff --git a/src/sqlancer/questdb/QuestDBToStringVisitor.java b/src/sqlancer/questdb/QuestDBToStringVisitor.java new file mode 100644 index 000000000..9363b30b1 --- /dev/null +++ b/src/sqlancer/questdb/QuestDBToStringVisitor.java @@ -0,0 +1,66 @@ +package sqlancer.questdb; + +import sqlancer.common.ast.newast.NewToStringVisitor; +import sqlancer.questdb.ast.QuestDBConstant; +import sqlancer.questdb.ast.QuestDBExpression; +import sqlancer.questdb.ast.QuestDBSelect; + +public class QuestDBToStringVisitor extends NewToStringVisitor { + + @Override + public void visitSpecific(QuestDBExpression expr) { + if (expr instanceof QuestDBConstant) { + visit((QuestDBConstant) expr); + } else if (expr instanceof QuestDBSelect) { + visit((QuestDBSelect) expr); + } else { // TODO: maybe implement QuestDBJoin + throw new AssertionError("Unknown class: " + expr.getClass()); + } + } + + private void visit(QuestDBConstant constant) { + sb.append(constant.toString()); + } + + private void visit(QuestDBSelect select) { + sb.append("SELECT "); + if (select.isDistinct()) { + sb.append("DISTINCT "); + } + visit(select.getFetchColumns()); + sb.append(" FROM "); + visit(select.getFromList()); + if (!select.getFromList().isEmpty() && !select.getJoinList().isEmpty()) { + sb.append(", "); + } + if (!select.getJoinList().isEmpty()) { + visit(select.getJoinList()); + } + if (select.getWhereClause() != null) { + sb.append(" WHERE "); + visit(select.getWhereClause()); + } + // if (!select.getGroupByExpressions().isEmpty()) { + // sb.append(" GROUP BY "); + // visit(select.getGroupByExpressions()); + // } + // if (select.getHavingClause() != null) { + // sb.append(" HAVING "); + // visit(select.getHavingClause()); + // } + // if (!select.getOrderByClauses().isEmpty()) { + // sb.append(" ORDER BY "); + // visit(select.getOrderByClauses()); + // } + if (select.getLimitClause() != null) { + sb.append(" LIMIT "); + visit(select.getLimitClause()); + } + } + + public static String asString(QuestDBExpression expr) { + QuestDBToStringVisitor visitor = new QuestDBToStringVisitor(); + visitor.visit(expr); + return visitor.get(); + } +} diff --git a/src/sqlancer/questdb/ast/QuestDBBinaryOperation.java b/src/sqlancer/questdb/ast/QuestDBBinaryOperation.java new file mode 100644 index 000000000..1b69b2c31 --- /dev/null +++ b/src/sqlancer/questdb/ast/QuestDBBinaryOperation.java @@ -0,0 +1,10 @@ +package sqlancer.questdb.ast; + +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.common.ast.newast.NewBinaryOperatorNode; + +public class QuestDBBinaryOperation extends NewBinaryOperatorNode implements QuestDBExpression { + public QuestDBBinaryOperation(QuestDBExpression left, QuestDBExpression right, Operator op) { + super(left, right, op); + } +} diff --git a/src/sqlancer/questdb/ast/QuestDBColumnReference.java b/src/sqlancer/questdb/ast/QuestDBColumnReference.java new file mode 100644 index 000000000..d627538a0 --- /dev/null +++ b/src/sqlancer/questdb/ast/QuestDBColumnReference.java @@ -0,0 +1,12 @@ +package sqlancer.questdb.ast; + +import sqlancer.common.ast.newast.ColumnReferenceNode; +import sqlancer.questdb.QuestDBSchema; + +public class QuestDBColumnReference extends ColumnReferenceNode + implements QuestDBExpression { + public QuestDBColumnReference(QuestDBSchema.QuestDBColumn column) { + super(column); + } + +} diff --git a/src/sqlancer/questdb/ast/QuestDBConstant.java b/src/sqlancer/questdb/ast/QuestDBConstant.java new file mode 100644 index 000000000..f2e65fbc0 --- /dev/null +++ b/src/sqlancer/questdb/ast/QuestDBConstant.java @@ -0,0 +1,111 @@ +package sqlancer.questdb.ast; + +public class QuestDBConstant implements QuestDBExpression { + private QuestDBConstant() { + } + + public static class QuestDBNullConstant extends QuestDBConstant { + @Override + public String toString() { + return "NULL"; + } + } + + public static class QuestDBIntConstant extends QuestDBConstant { + private final long value; + + public QuestDBIntConstant(long value) { + this.value = value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + + public long getValue() { + return value; + } + } + + public static class QuestDBBooleanConstant extends QuestDBConstant { + private final boolean value; + + public QuestDBBooleanConstant(boolean value) { + this.value = value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + + public boolean getValue() { + return value; + } + } + + public static class QuestDBSymbolConstant extends QuestDBConstant { + private final String value; + + public QuestDBSymbolConstant(String value) { + this.value = value; + } + + @Override + public String toString() { + if (value.equals("")) { + return "NULL"; + } + return "'" + value + "'"; + } + + public String getValue() { + return value; + } + } + + public static QuestDBExpression createIntConstant(long val) { + return new QuestDBIntConstant(val); + } + + public static class QuestDBDoubleConstant extends QuestDBConstant { + + private final double value; + + public QuestDBDoubleConstant(double value) { + this.value = value; + } + + public double getValue() { + return value; + } + + @Override + public String toString() { + if (value == Double.POSITIVE_INFINITY) { + return "cast('Infinity' as double)"; + } else if (value == Double.NEGATIVE_INFINITY) { + return "cast('-Infinity' as double)"; + } + return String.valueOf(value); + } + + } + + public static QuestDBExpression createBooleanConstant(boolean val) { + return new QuestDBBooleanConstant(val); + } + + public static QuestDBExpression createNullConstant() { + return new QuestDBNullConstant(); + } + + public static QuestDBExpression createFloatConstant(double val) { + return new QuestDBDoubleConstant(val); + } + + public static QuestDBExpression createSymbolConstant(String val) { + return new QuestDBSymbolConstant(val); + } +} diff --git a/src/sqlancer/questdb/ast/QuestDBExpression.java b/src/sqlancer/questdb/ast/QuestDBExpression.java new file mode 100644 index 000000000..c2c7f1d74 --- /dev/null +++ b/src/sqlancer/questdb/ast/QuestDBExpression.java @@ -0,0 +1,4 @@ +package sqlancer.questdb.ast; + +public interface QuestDBExpression { +} diff --git a/src/sqlancer/questdb/ast/QuestDBInOperation.java b/src/sqlancer/questdb/ast/QuestDBInOperation.java new file mode 100644 index 000000000..ee4ba2b59 --- /dev/null +++ b/src/sqlancer/questdb/ast/QuestDBInOperation.java @@ -0,0 +1,11 @@ +package sqlancer.questdb.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewInOperatorNode; + +public class QuestDBInOperation extends NewInOperatorNode implements QuestDBExpression { + public QuestDBInOperation(QuestDBExpression left, List right, boolean isNegated) { + super(left, right, isNegated); + } +} diff --git a/src/sqlancer/questdb/ast/QuestDBSelect.java b/src/sqlancer/questdb/ast/QuestDBSelect.java new file mode 100644 index 000000000..0684b4082 --- /dev/null +++ b/src/sqlancer/questdb/ast/QuestDBSelect.java @@ -0,0 +1,15 @@ +package sqlancer.questdb.ast; + +import sqlancer.common.ast.SelectBase; + +public class QuestDBSelect extends SelectBase implements QuestDBExpression { + private boolean isDistinct; + + public void setDistinct(boolean distinct) { + isDistinct = distinct; + } + + public boolean isDistinct() { + return isDistinct; + } +} diff --git a/src/sqlancer/questdb/ast/QuestDBTableReference.java b/src/sqlancer/questdb/ast/QuestDBTableReference.java new file mode 100644 index 000000000..8e4d387da --- /dev/null +++ b/src/sqlancer/questdb/ast/QuestDBTableReference.java @@ -0,0 +1,11 @@ +package sqlancer.questdb.ast; + +import sqlancer.common.ast.newast.TableReferenceNode; +import sqlancer.questdb.QuestDBSchema; + +public class QuestDBTableReference extends TableReferenceNode + implements QuestDBExpression { + public QuestDBTableReference(QuestDBSchema.QuestDBTable table) { + super(table); + } +} diff --git a/src/sqlancer/questdb/ast/QuestDBUnaryPostfixOperation.java b/src/sqlancer/questdb/ast/QuestDBUnaryPostfixOperation.java new file mode 100644 index 000000000..0301ba60a --- /dev/null +++ b/src/sqlancer/questdb/ast/QuestDBUnaryPostfixOperation.java @@ -0,0 +1,11 @@ +package sqlancer.questdb.ast; + +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.newast.NewUnaryPostfixOperatorNode; + +public class QuestDBUnaryPostfixOperation extends NewUnaryPostfixOperatorNode + implements QuestDBExpression { + public QuestDBUnaryPostfixOperation(QuestDBExpression expr, BinaryOperatorNode.Operator op) { + super(expr, op); + } +} diff --git a/src/sqlancer/questdb/ast/QuestDBUnaryPrefixOperation.java b/src/sqlancer/questdb/ast/QuestDBUnaryPrefixOperation.java new file mode 100644 index 000000000..8488a0de8 --- /dev/null +++ b/src/sqlancer/questdb/ast/QuestDBUnaryPrefixOperation.java @@ -0,0 +1,11 @@ +package sqlancer.questdb.ast; + +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode; + +public class QuestDBUnaryPrefixOperation extends NewUnaryPrefixOperatorNode + implements QuestDBExpression { + public QuestDBUnaryPrefixOperation(QuestDBExpression expr, BinaryOperatorNode.Operator operator) { + super(expr, operator); + } +} diff --git a/src/sqlancer/questdb/gen/QuestDBAlterIndexGenerator.java b/src/sqlancer/questdb/gen/QuestDBAlterIndexGenerator.java new file mode 100644 index 000000000..85103738f --- /dev/null +++ b/src/sqlancer/questdb/gen/QuestDBAlterIndexGenerator.java @@ -0,0 +1,61 @@ +package sqlancer.questdb.gen; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.questdb.QuestDBProvider.QuestDBGlobalState; +import sqlancer.questdb.QuestDBSchema.QuestDBColumn; +import sqlancer.questdb.QuestDBSchema.QuestDBDataType; +import sqlancer.questdb.QuestDBSchema.QuestDBTable; + +public final class QuestDBAlterIndexGenerator { + private QuestDBAlterIndexGenerator() { + } + + enum Action { + ADD_INDEX, DROP_INDEX + } + + public static SQLQueryAdapter getQuery(QuestDBGlobalState globalState) { + ExpectedErrors errors = new ExpectedErrors(); + errors.add(" does not have a column with name \"rowid\""); + errors.add("Table does not contain column rowid referenced in alter statement"); + errors.add("cannot create index"); + errors.add("Index flag is only supported for SYMBOL"); + errors.add("Invalid column: "); + + StringBuilder sb = new StringBuilder("ALTER TABLE "); + + QuestDBTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + sb.append(table.getName()); + sb.append(" "); + + sb.append("ALTER COLUMN "); + + // We should always choose column with SYMBOL type + QuestDBColumn columnWithSymbolType = table + .getRandomColumnOrBailout(c -> c.getType().getPrimitiveDataType() == QuestDBDataType.SYMBOL); + + String columnName = columnWithSymbolType.getName(); + + sb.append(columnName); + sb.append(" "); + + Action action = Randomly.fromOptions(Action.values()); + switch (action) { + case ADD_INDEX: + sb.append("ADD INDEX"); + errors.add("already exists!"); + + break; + case DROP_INDEX: + sb.append("DROP INDEX"); + errors.add("Column is not indexed"); + break; + default: + throw new AssertionError("Unkown action:" + action); + } + + return new SQLQueryAdapter(sb.toString(), errors, true); + } +} diff --git a/src/sqlancer/questdb/gen/QuestDBExpressionGenerator.java b/src/sqlancer/questdb/gen/QuestDBExpressionGenerator.java new file mode 100644 index 000000000..cbb19e705 --- /dev/null +++ b/src/sqlancer/questdb/gen/QuestDBExpressionGenerator.java @@ -0,0 +1,203 @@ +package sqlancer.questdb.gen; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.Randomly.StringGenerationStrategy; +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.common.gen.UntypedExpressionGenerator; +import sqlancer.questdb.QuestDBProvider.QuestDBGlobalState; +import sqlancer.questdb.QuestDBSchema.QuestDBColumn; +import sqlancer.questdb.QuestDBSchema.QuestDBDataType; +import sqlancer.questdb.ast.QuestDBBinaryOperation; +import sqlancer.questdb.ast.QuestDBColumnReference; +import sqlancer.questdb.ast.QuestDBConstant; +import sqlancer.questdb.ast.QuestDBExpression; +import sqlancer.questdb.ast.QuestDBInOperation; +import sqlancer.questdb.ast.QuestDBUnaryPostfixOperation; +import sqlancer.questdb.ast.QuestDBUnaryPrefixOperation; + +public class QuestDBExpressionGenerator extends UntypedExpressionGenerator { + + private final QuestDBGlobalState globalState; + + public QuestDBExpressionGenerator(QuestDBGlobalState globalState) { + this.globalState = globalState; + } + + private enum Expression { + UNARY_POSTFIX, UNARY_PREFIX, BINARY_COMPARISON, BINARY_LOGICAL, BINARY_ARITHMETIC, IN + } + + @Override + public QuestDBExpression negatePredicate(QuestDBExpression predicate) { + return new QuestDBUnaryPrefixOperation(predicate, QuestDBUnaryPrefixOperator.NOT); + } + + @Override + public QuestDBExpression isNull(QuestDBExpression expr) { + return new QuestDBUnaryPostfixOperation(expr, QuestDBUnaryPostfixOperator.IS_NULL); + } + + @Override + public QuestDBExpression generateConstant() { + if (Randomly.getBooleanWithSmallProbability()) { + return QuestDBConstant.createNullConstant(); + } + QuestDBDataType type = QuestDBDataType.getRandomWithoutNull(); + switch (type) { + case INT: + return QuestDBConstant.createIntConstant(globalState.getRandomly().getInteger()); + case BOOLEAN: + return QuestDBConstant.createBooleanConstant(Randomly.getBoolean()); + case FLOAT: + return QuestDBConstant.createFloatConstant(globalState.getRandomly().getDouble()); + case SYMBOL: + StringGenerationStrategy strategy = Randomly.StringGenerationStrategy.ALPHANUMERIC; + return QuestDBConstant.createSymbolConstant(strategy.getString(globalState.getRandomly())); + // case CHAR: + // case DATE: + // case TIMESTAMP: + // throw new IgnoreMeException(); + default: + throw new AssertionError("Unknown type: " + type); + } + } + + @Override + protected QuestDBExpression generateExpression(int depth) { + if (depth >= globalState.getOptions().getMaxExpressionDepth() || Randomly.getBoolean()) { + return generateLeafNode(); + } + + List possibleOptions = new ArrayList<>(Arrays.asList(Expression.values())); + Expression expr = Randomly.fromList(possibleOptions); + + switch (expr) { + case UNARY_PREFIX: + return new QuestDBUnaryPrefixOperation(generateExpression(depth + 1), + QuestDBUnaryPrefixOperator.getRandom()); + case UNARY_POSTFIX: + return new QuestDBUnaryPostfixOperation(generateExpression(depth + 1), + QuestDBUnaryPostfixOperator.getRandom()); + case BINARY_COMPARISON: + return new QuestDBBinaryOperation(generateExpression(depth + 1), generateExpression(depth + 1), + QuestDBBinaryComparisonOperator.getRandom()); + case BINARY_ARITHMETIC: + return new QuestDBBinaryOperation(generateExpression(depth + 1), generateExpression(depth + 1), + QuestDBBinaryArithmeticOperator.getRandom()); + case BINARY_LOGICAL: + return new QuestDBBinaryOperation(generateExpression(depth + 1), generateExpression(depth + 1), + QuestDBBinaryLogicalOperator.getRandom()); + case IN: + return new QuestDBInOperation(generateExpression(depth + 1), + generateExpressions(Randomly.smallNumber() + 1, depth + 1), Randomly.getBoolean()); + default: + throw new AssertionError("Expression generation failed, depth=" + depth); + } + } + + @Override + protected QuestDBExpression generateColumn() { + QuestDBColumn column = Randomly.fromList(columns); + return new QuestDBColumnReference(column); + } + + public enum QuestDBUnaryPostfixOperator implements Operator { + IS_NULL("IS NULL"), IS_NOT_NULL("IS NOT NULL"); + + private String textRepr; + + QuestDBUnaryPostfixOperator(String textRepr) { + this.textRepr = textRepr; + } + + @Override + public String getTextRepresentation() { + return textRepr; + } + + public static QuestDBUnaryPostfixOperator getRandom() { + return Randomly.fromOptions(values()); + } + } + + public enum QuestDBUnaryPrefixOperator implements Operator { + + NOT("NOT"); + + private String textRepr; + + QuestDBUnaryPrefixOperator(String textRepr) { + this.textRepr = textRepr; + } + + @Override + public String getTextRepresentation() { + return textRepr; + } + + public static QuestDBUnaryPrefixOperator getRandom() { + return Randomly.fromOptions(values()); + } + + } + + public enum QuestDBBinaryLogicalOperator implements Operator { + + AND, OR; + + @Override + public String getTextRepresentation() { + return toString(); + } + + public static Operator getRandom() { + return Randomly.fromOptions(values()); + } + + } + + public enum QuestDBBinaryComparisonOperator implements Operator { + EQUALS("="), GREATER_THAN(">"), GREATER_THAN_EQUALS(">="), LESS_THAN("<"), SMALLER_THAN_EQUALS("<="), + NOT_EQUALS("!="), REGEX_POSIX("~"), REGEX_POSIT_NOT("!~"); + + private String textRepr; + + QuestDBBinaryComparisonOperator(String textRepr) { + this.textRepr = textRepr; + } + + public static Operator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepr; + } + + } + + public enum QuestDBBinaryArithmeticOperator implements Operator { + CONCAT("||"), ADD("+"), SUB("-"), MULT("*"), DIV("/"), MOD("%"), AND("&"), OR("|"); // , LSHIFT("<<"), + // RSHIFT(">>"); + + private String textRepr; + + QuestDBBinaryArithmeticOperator(String textRepr) { + this.textRepr = textRepr; + } + + public static Operator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepr; + } + } +} diff --git a/src/sqlancer/questdb/gen/QuestDBInsertGenerator.java b/src/sqlancer/questdb/gen/QuestDBInsertGenerator.java new file mode 100644 index 000000000..754a0d2e5 --- /dev/null +++ b/src/sqlancer/questdb/gen/QuestDBInsertGenerator.java @@ -0,0 +1,60 @@ +package sqlancer.questdb.gen; + +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.common.gen.AbstractInsertGenerator; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.questdb.QuestDBErrors; +import sqlancer.questdb.QuestDBProvider.QuestDBGlobalState; +import sqlancer.questdb.QuestDBSchema.QuestDBColumn; +import sqlancer.questdb.QuestDBSchema.QuestDBTable; +import sqlancer.questdb.QuestDBToStringVisitor; + +public class QuestDBInsertGenerator extends AbstractInsertGenerator { + + private final QuestDBGlobalState globalState; + + private final ExpectedErrors errors = new ExpectedErrors(); + + public QuestDBInsertGenerator(QuestDBGlobalState globalState) { + this.globalState = globalState; + } + + private SQLQueryAdapter generate() { + sb.append("INSERT INTO "); + QuestDBTable table = globalState.getSchema().getRandomTable(); + List columns = table.getRandomNonEmptyColumnSubset(); + sb.append(table.getName()); + sb.append("("); + sb.append(columns.stream().map(AbstractTableColumn::getName).collect(Collectors.joining(", "))); + sb.append(")"); + sb.append(" VALUES "); + insertColumns(columns); + QuestDBErrors.addInsertErrors(errors); + return new SQLQueryAdapter(sb.toString(), errors); + } + + public static SQLQueryAdapter getQuery(QuestDBGlobalState globalState) { + return new QuestDBInsertGenerator(globalState).generate(); + } + + @Override + protected void insertColumns(List columns) { + sb.append("("); + for (int nrColumn = 0; nrColumn < columns.size(); nrColumn++) { + if (nrColumn != 0) { + sb.append(", "); + } + insertValue(columns.get(nrColumn)); + } + sb.append(")"); + } + + @Override + protected void insertValue(QuestDBColumn questDBColumn) { + sb.append(QuestDBToStringVisitor.asString(new QuestDBExpressionGenerator(globalState).generateConstant())); + } +} diff --git a/src/sqlancer/questdb/gen/QuestDBTableGenerator.java b/src/sqlancer/questdb/gen/QuestDBTableGenerator.java new file mode 100644 index 000000000..a308b8d10 --- /dev/null +++ b/src/sqlancer/questdb/gen/QuestDBTableGenerator.java @@ -0,0 +1,53 @@ +package sqlancer.questdb.gen; + +import java.util.ArrayList; +import java.util.List; +import javax.annotation.Nullable; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.questdb.QuestDBProvider.QuestDBGlobalState; +import sqlancer.questdb.QuestDBSchema.QuestDBColumn; +import sqlancer.questdb.QuestDBSchema.QuestDBCompositeDataType; + +public class QuestDBTableGenerator { + + public SQLQueryAdapter getQuery(QuestDBGlobalState globalState, @Nullable String tableName) { + ExpectedErrors errors = new ExpectedErrors(); + StringBuilder sb = new StringBuilder(); + String name = tableName; + if (tableName == null) { + name = globalState.getSchema().getFreeTableName(); + } + sb.append("CREATE TABLE "); + if (Randomly.getBoolean()) { + sb.append("IF NOT EXISTS "); + } + sb.append(name); + sb.append("("); + List columns = getNewColumns(); + for (int i = 0; i < columns.size(); i++) { + if (i != 0) { + sb.append(", "); + } + sb.append(columns.get(i).getName()); + sb.append(" "); + sb.append(columns.get(i).getType()); + } + sb.append(")"); + sb.append(";"); + errors.add("table already exists"); + return new SQLQueryAdapter(sb.toString(), errors, true); + } + + private static List getNewColumns() { + List columns = new ArrayList<>(); + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + String columnName = String.format("c%d", i); + QuestDBCompositeDataType columnType = QuestDBCompositeDataType.getRandomWithoutNull(); + columns.add(new QuestDBColumn(columnName, columnType, false)); + } + return columns; + } +} diff --git a/src/sqlancer/questdb/gen/QuestDBTruncateGenerator.java b/src/sqlancer/questdb/gen/QuestDBTruncateGenerator.java new file mode 100644 index 000000000..19c48f9f7 --- /dev/null +++ b/src/sqlancer/questdb/gen/QuestDBTruncateGenerator.java @@ -0,0 +1,22 @@ +package sqlancer.questdb.gen; + +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.questdb.QuestDBErrors; +import sqlancer.questdb.QuestDBProvider.QuestDBGlobalState; +import sqlancer.questdb.QuestDBSchema.QuestDBTable; + +public final class QuestDBTruncateGenerator { + private QuestDBTruncateGenerator() { + + } + + public static SQLQueryAdapter generate(QuestDBGlobalState globalState) { + StringBuilder sb = new StringBuilder("TRUNCATE TABLE "); + ExpectedErrors errors = new ExpectedErrors(); + QuestDBTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + sb.append(table.getName()); + QuestDBErrors.addExpressionErrors(errors); + return new SQLQueryAdapter(sb.toString(), errors); + } +} diff --git a/src/sqlancer/questdb/gen/QuestDBUpdateGenerator.java b/src/sqlancer/questdb/gen/QuestDBUpdateGenerator.java new file mode 100644 index 000000000..cc7fba1df --- /dev/null +++ b/src/sqlancer/questdb/gen/QuestDBUpdateGenerator.java @@ -0,0 +1,16 @@ +package sqlancer.questdb.gen; + +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.questdb.QuestDBProvider.QuestDBGlobalState; + +public final class QuestDBUpdateGenerator { + + private QuestDBUpdateGenerator() { + } + + public static SQLQueryAdapter getQuery(QuestDBGlobalState globalState) { + // TODO + return null; + } + +} diff --git a/src/sqlancer/questdb/gen/QuestDBViewGenerator.java b/src/sqlancer/questdb/gen/QuestDBViewGenerator.java new file mode 100644 index 000000000..0ff4d232b --- /dev/null +++ b/src/sqlancer/questdb/gen/QuestDBViewGenerator.java @@ -0,0 +1,14 @@ +package sqlancer.questdb.gen; + +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.questdb.QuestDBProvider.QuestDBGlobalState; + +public final class QuestDBViewGenerator { + private QuestDBViewGenerator() { + } + + public static SQLQueryAdapter generate(QuestDBGlobalState globalState) { + // TODO + return null; + } +} diff --git a/src/sqlancer/questdb/test/QuestDBQueryPartitioningBase.java b/src/sqlancer/questdb/test/QuestDBQueryPartitioningBase.java new file mode 100644 index 000000000..0716e9ccc --- /dev/null +++ b/src/sqlancer/questdb/test/QuestDBQueryPartitioningBase.java @@ -0,0 +1,71 @@ +package sqlancer.questdb.test; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.common.gen.ExpressionGenerator; +import sqlancer.common.oracle.TernaryLogicPartitioningOracleBase; +import sqlancer.common.oracle.TestOracle; +import sqlancer.questdb.QuestDBErrors; +import sqlancer.questdb.QuestDBProvider.QuestDBGlobalState; +import sqlancer.questdb.QuestDBSchema; +import sqlancer.questdb.QuestDBSchema.QuestDBColumn; +import sqlancer.questdb.QuestDBSchema.QuestDBTable; +import sqlancer.questdb.ast.QuestDBColumnReference; +import sqlancer.questdb.ast.QuestDBExpression; +import sqlancer.questdb.ast.QuestDBSelect; +import sqlancer.questdb.ast.QuestDBTableReference; +import sqlancer.questdb.gen.QuestDBExpressionGenerator; + +public class QuestDBQueryPartitioningBase + extends TernaryLogicPartitioningOracleBase + implements TestOracle { + + QuestDBSchema s; + QuestDBTable targetTable; + QuestDBExpressionGenerator gen; + QuestDBSelect select; + + protected QuestDBQueryPartitioningBase(QuestDBGlobalState state) { + super(state); + QuestDBErrors.addExpressionErrors(errors); + } + + List generateFetchColumns() { + List columns = new ArrayList<>(); + if (Randomly.getBoolean()) { + columns.add(new QuestDBColumnReference(new QuestDBColumn("*", null, false))); + } else { + columns = Randomly.nonEmptySubset(targetTable.getColumns()).stream().map(c -> new QuestDBColumnReference(c)) + .collect(Collectors.toList()); + } + return columns; + } + + @Override + protected ExpressionGenerator getGen() { + return gen; + } + + @Override + public void check() throws SQLException { + s = state.getSchema(); + // Only return one table instead of multiple tables, which is regarded as illegal by QuestDB + // e.g. "SELECT * FROM t0, t1;" + targetTable = s.getRandomTable(); + gen = new QuestDBExpressionGenerator(state).setColumns(targetTable.getColumns()); + initializeTernaryPredicateVariants(); + select = new QuestDBSelect(); + select.setFetchColumns(generateFetchColumns()); + List tables = new ArrayList<>(); + tables.add(targetTable); + List tableList = tables.stream().map(t -> new QuestDBTableReference(t)) + .collect(Collectors.toList()); + // Ignore JOINs for now + select.setFromList(new ArrayList<>(tableList)); + select.setWhereClause(null); + } +} diff --git a/src/sqlancer/questdb/test/QuestDBQueryPartitioningWhereTester.java b/src/sqlancer/questdb/test/QuestDBQueryPartitioningWhereTester.java new file mode 100644 index 000000000..ed89e7136 --- /dev/null +++ b/src/sqlancer/questdb/test/QuestDBQueryPartitioningWhereTester.java @@ -0,0 +1,41 @@ +package sqlancer.questdb.test; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import sqlancer.ComparatorHelper; +import sqlancer.questdb.QuestDBErrors; +import sqlancer.questdb.QuestDBProvider.QuestDBGlobalState; +import sqlancer.questdb.QuestDBToStringVisitor; + +public class QuestDBQueryPartitioningWhereTester extends QuestDBQueryPartitioningBase { + public QuestDBQueryPartitioningWhereTester(QuestDBGlobalState state) { + super(state); + QuestDBErrors.addGroupByErrors(errors); + } + + @Override + public void check() throws SQLException { + super.check(); + select.setWhereClause(null); + String originalQueryString = QuestDBToStringVisitor.asString(select); + + List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); + + // Ignore OrderBy for now + + select.setWhereClause(predicate); + String firstQueryString = QuestDBToStringVisitor.asString(select); + select.setWhereClause(negatedPredicate); + String secondQueryString = QuestDBToStringVisitor.asString(select); + select.setWhereClause(isNullPredicate); + String thirdQueryString = QuestDBToStringVisitor.asString(select); + + List combinedString = new ArrayList<>(); + List secondResultSet = ComparatorHelper.getCombinedResultSet(firstQueryString, secondQueryString, + thirdQueryString, combinedString, false, state, errors); + ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, + state, ComparatorHelper::canonicalizeResultValue); + } +} diff --git a/src/sqlancer/spark/SparkErrors.java b/src/sqlancer/spark/SparkErrors.java new file mode 100644 index 000000000..a3a96061f --- /dev/null +++ b/src/sqlancer/spark/SparkErrors.java @@ -0,0 +1,65 @@ +package sqlancer.spark; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.common.query.ExpectedErrors; + +public final class SparkErrors { + + private SparkErrors() { + } + + public static List getExpressionErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("cannot resolve"); + errors.add("AnalysisException"); + errors.add("data type mismatch"); + errors.add("undefined function"); + errors.add("mismatched input"); + errors.add("due to data type mismatch"); + + // --- Invalid Literals + errors.add("The value of the typed literal"); + + errors.add("DATATYPE_MISMATCH"); + errors.add("cannot be cast to"); + + errors.add("Overflow"); + errors.add("Divide by zero"); // Common if spark.sql.ansi.enabled is true + errors.add("division by zero"); + + // --- Group By / Aggregation errors --- + errors.add("grouping expressions"); + errors.add("expression is neither present in the group by"); + errors.add("is not a valid grouping expression"); + errors.add("is not contained in either an aggregate function or the GROUP BY clause"); + + return errors; + } + + public static void addExpressionErrors(ExpectedErrors errors) { + errors.addAll(getExpressionErrors()); + } + + public static List getInsertErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("not enough data columns"); + errors.add("cannot write to"); + errors.add("incompatible types"); + errors.add("too many data columns"); + errors.add("cannot be cast to"); + errors.add("Error running query"); + errors.add("The value of the typed literal"); + errors.add("Cannot safely cast"); // Found in logs: Decimal -> Date + errors.add("AnalysisException"); // Spark throws this for almost all insert failures + + return errors; + } + + public static void addInsertErrors(ExpectedErrors errors) { + errors.addAll(getInsertErrors()); + } +} diff --git a/src/sqlancer/spark/SparkGlobalState.java b/src/sqlancer/spark/SparkGlobalState.java new file mode 100644 index 000000000..d78c737e5 --- /dev/null +++ b/src/sqlancer/spark/SparkGlobalState.java @@ -0,0 +1,11 @@ +package sqlancer.spark; + +import sqlancer.SQLGlobalState; + +public class SparkGlobalState extends SQLGlobalState { + + @Override + protected SparkSchema readSchema() throws Exception { + return SparkSchema.fromConnection(getConnection(), getDatabaseName()); + } +} diff --git a/src/sqlancer/spark/SparkOptions.java b/src/sqlancer/spark/SparkOptions.java new file mode 100644 index 000000000..7b347ceef --- /dev/null +++ b/src/sqlancer/spark/SparkOptions.java @@ -0,0 +1,43 @@ +package sqlancer.spark; + +import java.sql.SQLException; +import java.util.Arrays; +import java.util.List; + +import com.beust.jcommander.Parameter; +import com.beust.jcommander.Parameters; + +import sqlancer.DBMSSpecificOptions; +import sqlancer.OracleFactory; +import sqlancer.common.oracle.TLPWhereOracle; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.spark.gen.SparkExpressionGenerator; + +@Parameters(separators = "=", commandDescription = "Spark SQL (default port: " + SparkOptions.DEFAULT_PORT + + ", default host: " + SparkOptions.DEFAULT_HOST + ")") +public class SparkOptions implements DBMSSpecificOptions { + public static final String DEFAULT_HOST = "localhost"; + public static final int DEFAULT_PORT = 10000; + + @Parameter(names = "--oracle") + public List oracle = Arrays.asList(SparkOracleFactory.TLPWhere); + + public enum SparkOracleFactory implements OracleFactory { + TLPWhere { + @Override + public TestOracle create(SparkGlobalState globalState) throws SQLException { + SparkExpressionGenerator gen = new SparkExpressionGenerator(globalState); + ExpectedErrors expectedErrors = ExpectedErrors.newErrors().with(SparkErrors.getExpressionErrors()) + .build(); + + return new TLPWhereOracle<>(globalState, gen, expectedErrors); + } + }; + } + + @Override + public List getTestOracleFactory() { + return oracle; + } +} diff --git a/src/sqlancer/spark/SparkProvider.java b/src/sqlancer/spark/SparkProvider.java new file mode 100644 index 000000000..817a92471 --- /dev/null +++ b/src/sqlancer/spark/SparkProvider.java @@ -0,0 +1,123 @@ +package sqlancer.spark; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.sql.Statement; + +import com.google.auto.service.AutoService; + +import sqlancer.AbstractAction; +import sqlancer.DatabaseProvider; +import sqlancer.IgnoreMeException; +import sqlancer.MainOptions; +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.SQLProviderAdapter; +import sqlancer.StatementExecutor; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.query.SQLQueryProvider; +import sqlancer.spark.gen.SparkInsertGenerator; +import sqlancer.spark.gen.SparkTableGenerator; + +@AutoService(DatabaseProvider.class) +public class SparkProvider extends SQLProviderAdapter { + + public SparkProvider() { + super(SparkGlobalState.class, SparkOptions.class); + } + + public enum Action implements AbstractAction { + INSERT(SparkInsertGenerator::getQuery); + + private final SQLQueryProvider sqlQueryProvider; + + Action(SQLQueryProvider sqlQueryProvider) { + this.sqlQueryProvider = sqlQueryProvider; + } + + @Override + public SQLQueryAdapter getQuery(SparkGlobalState state) throws Exception { + return sqlQueryProvider.getQuery(state); + } + } + + private static int mapActions(SparkGlobalState globalState, Action a) { + Randomly r = globalState.getRandomly(); + switch (a) { + case INSERT: + return r.getInteger(0, globalState.getOptions().getMaxNumberInserts()); + default: + throw new AssertionError(a); + } + } + + @Override + public void generateDatabase(SparkGlobalState globalState) throws Exception { + for (int i = 0; i < Randomly.fromOptions(1, 2); i++) { + boolean success; + do { + String tableName = globalState.getSchema().getFreeTableName(); + SQLQueryAdapter qt = SparkTableGenerator.generate(globalState, tableName); + success = globalState.executeStatement(qt); + } while (!success); + } + + if (globalState.getSchema().getDatabaseTables().isEmpty()) { + throw new IgnoreMeException(); + } + + StatementExecutor se = new StatementExecutor<>(globalState, Action.values(), + SparkProvider::mapActions, (q) -> { + if (globalState.getSchema().getDatabaseTables().isEmpty()) { + throw new IgnoreMeException(); + } + }); + se.executeStatements(); + } + + @Override + public SQLConnection createDatabase(SparkGlobalState globalState) throws SQLException { + String username = globalState.getOptions().getUserName(); + String password = globalState.getOptions().getPassword(); + String host = globalState.getOptions().getHost(); + int port = globalState.getOptions().getPort(); + + if (host == null) { + host = SparkOptions.DEFAULT_HOST; + } + if (port == MainOptions.NO_SET_PORT) { + port = SparkOptions.DEFAULT_PORT; + } + + String databaseName = globalState.getDatabaseName(); + + // Spark uses the Hive driver for JDBC usually + String url = String.format("jdbc:hive2://%s:%d/%s", host, port, "default"); + + // Connect to default to create the fuzzing DB + Connection con = DriverManager.getConnection(url, username, password); + try (Statement s = con.createStatement()) { + s.execute("DROP DATABASE IF EXISTS " + databaseName + " CASCADE"); + } + try (Statement s = con.createStatement()) { + s.execute("CREATE DATABASE " + databaseName); + } + con.close(); + + // Connect to the specific fuzzing DB + con = DriverManager.getConnection(String.format("jdbc:hive2://%s:%d/%s", host, port, databaseName), username, + password); + try (Statement s = con.createStatement()) { + // This allows casting things like BOOLEAN to DATE/TIMESTAMP, which the + // generator loves to do. + s.execute("SET spark.sql.ansi.enabled=false"); + } + return new SQLConnection(con); + } + + @Override + public String getDBMSName() { + return "spark"; + } +} diff --git a/src/sqlancer/spark/SparkSchema.java b/src/sqlancer/spark/SparkSchema.java new file mode 100644 index 000000000..9b3666916 --- /dev/null +++ b/src/sqlancer/spark/SparkSchema.java @@ -0,0 +1,122 @@ +package sqlancer.spark; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.common.schema.AbstractRelationalTable; +import sqlancer.common.schema.AbstractSchema; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.common.schema.AbstractTables; +import sqlancer.common.schema.TableIndex; +import sqlancer.spark.SparkSchema.SparkTable; + +public class SparkSchema extends AbstractSchema { + + public enum SparkDataType { + STRING, INTEGER, DOUBLE, BOOLEAN, TIMESTAMP, DATE; + + public static SparkDataType getRandomType() { + return Randomly.fromList(Arrays.asList(values())); + } + } + + public static class SparkColumn extends AbstractTableColumn { + public SparkColumn(String name, SparkTable table, SparkDataType type) { + super(name, table, type); + } + } + + public static class SparkTables extends AbstractTables { + public SparkTables(List tables) { + super(tables); + } + } + + public static class SparkTable extends AbstractRelationalTable { + public SparkTable(String name, List columns, boolean isView) { + super(name, columns, Collections.emptyList(), isView); + } + } + + public SparkSchema(List databaseTables) { + super(databaseTables); + } + + public static SparkSchema fromConnection(SQLConnection con, String databaseName) throws SQLException { + List databaseTables = new ArrayList<>(); + List tableNames = getTableNames(con); + for (String tableName : tableNames) { + List databaseColumns = getTableColumns(con, tableName); + boolean isView = matchesViewName(tableName); + SparkTable t = new SparkTable(tableName, databaseColumns, isView); + for (SparkColumn c : databaseColumns) { + c.setTable(t); + } + databaseTables.add(t); + } + return new SparkSchema(databaseTables); + } + + private static List getTableNames(SQLConnection con) throws SQLException { + List tableNames = new ArrayList<>(); + try (Statement s = con.createStatement()) { + ResultSet tableRs = s.executeQuery("SHOW TABLES"); + while (tableRs.next()) { + // Spark SHOW TABLES output: database, tableName, isTemporary + String tableName = tableRs.getString("tableName"); + tableNames.add(tableName); + } + } + return tableNames; + } + + private static List getTableColumns(SQLConnection con, String tableName) throws SQLException { + List columns = new ArrayList<>(); + try (Statement s = con.createStatement()) { + try (ResultSet rs = s.executeQuery(String.format("DESCRIBE %s", tableName))) { + while (rs.next()) { + String columnName = rs.getString("col_name"); + String dataType = rs.getString("data_type"); + // Filter out Spark partition info or comments usually at bottom of describe + if (columnName.startsWith("#") || columnName.isEmpty()) { + continue; + } + + columns.add(new SparkColumn(columnName, null, getColumnType(dataType))); + } + } + } + return columns; + } + + private static SparkDataType getColumnType(String typeString) { + String upper = typeString.toUpperCase(); + if (upper.startsWith("STRING") || upper.startsWith("VARCHAR") || upper.startsWith("CHAR")) { + return SparkDataType.STRING; + } + if (upper.startsWith("INT") || upper.startsWith("BIGINT") || upper.startsWith("SMALLINT")) { + return SparkDataType.INTEGER; + } + if (upper.startsWith("DOUBLE") || upper.startsWith("FLOAT") || upper.startsWith("DECIMAL")) { + return SparkDataType.DOUBLE; + } + if (upper.startsWith("BOOLEAN")) { + return SparkDataType.BOOLEAN; + } + if (upper.startsWith("TIMESTAMP")) { + return SparkDataType.TIMESTAMP; + } + if (upper.startsWith("DATE")) { + return SparkDataType.DATE; + } + return SparkDataType.STRING; // Fallback + } + +} diff --git a/src/sqlancer/spark/SparkToStringVisitor.java b/src/sqlancer/spark/SparkToStringVisitor.java new file mode 100644 index 000000000..0777c86a6 --- /dev/null +++ b/src/sqlancer/spark/SparkToStringVisitor.java @@ -0,0 +1,121 @@ +package sqlancer.spark; + +import sqlancer.common.ast.newast.NewToStringVisitor; +import sqlancer.common.ast.newast.TableReferenceNode; +import sqlancer.spark.ast.SparkCastOperation; +import sqlancer.spark.ast.SparkConstant; +import sqlancer.spark.ast.SparkExpression; +import sqlancer.spark.ast.SparkJoin; +import sqlancer.spark.ast.SparkSelect; + +public class SparkToStringVisitor extends NewToStringVisitor { + + @Override + public void visitSpecific(SparkExpression expr) { + if (expr instanceof SparkConstant) { + visit((SparkConstant) expr); + } else if (expr instanceof SparkSelect) { + visit((SparkSelect) expr); + } else if (expr instanceof SparkJoin) { + visit((SparkJoin) expr); + } else if (expr instanceof SparkCastOperation) { + visit((SparkCastOperation) expr); + } else { + throw new AssertionError(expr.getClass()); + } + } + + private void visit(SparkConstant constant) { + sb.append(constant.toString()); + } + + private void visit(SparkSelect select) { + sb.append("SELECT "); + if (select.isDistinct()) { + sb.append("DISTINCT "); + } + visit(select.getFetchColumns()); + sb.append(" FROM "); + visit(select.getFromList()); + if (!select.getFromList().isEmpty() && !select.getJoinList().isEmpty()) { + sb.append(", "); + } + if (!select.getJoinList().isEmpty()) { + visit(select.getJoinList()); + } + if (select.getWhereClause() != null) { + sb.append(" WHERE "); + visit(select.getWhereClause()); + } + if (!select.getGroupByExpressions().isEmpty()) { + sb.append(" GROUP BY "); + visit(select.getGroupByExpressions()); + } + if (select.getHavingClause() != null) { + sb.append(" HAVING "); + visit(select.getHavingClause()); + } + if (!select.getOrderByClauses().isEmpty()) { + sb.append(" ORDER BY "); + visit(select.getOrderByClauses()); + } + if (select.getLimitClause() != null) { + sb.append(" LIMIT "); + visit(select.getLimitClause()); + } + // Spark supports OFFSET, though strictly usually with LIMIT or in newer + // versions + if (select.getOffsetClause() != null) { + sb.append(" OFFSET "); + visit(select.getOffsetClause()); + } + } + + private void visit(SparkJoin join) { + switch (join.getJoinType()) { + case INNER: + sb.append(" INNER JOIN "); + break; + case LEFT_OUTER: + sb.append(" LEFT JOIN "); + break; + case RIGHT_OUTER: + sb.append(" RIGHT JOIN "); + break; + case FULL_OUTER: + sb.append(" FULL JOIN "); + break; + case LEFT_SEMI: + sb.append(" LEFT SEMI JOIN "); + break; + // Spark also supports LEFT ANTI, which Hive might lack in some older versions + case LEFT_ANTI: + sb.append(" LEFT ANTI JOIN "); + break; + case CROSS: + sb.append(" CROSS JOIN "); + break; + default: + throw new UnsupportedOperationException("Join type not supported in Spark visitor: " + join.getJoinType()); + } + visit((TableReferenceNode) join.getRightTable()); + if (join.getOnClause() != null) { + sb.append(" ON "); + visit(join.getOnClause()); + } + } + + private void visit(SparkCastOperation cast) { + sb.append("CAST("); + visit(cast.getExpression()); + sb.append(" AS "); + sb.append(cast.getType()); + sb.append(")"); + } + + public static String asString(SparkExpression expr) { + SparkToStringVisitor visitor = new SparkToStringVisitor(); + visitor.visit(expr); + return visitor.get(); + } +} diff --git a/src/sqlancer/spark/ast/SparkBetweenOperation.java b/src/sqlancer/spark/ast/SparkBetweenOperation.java new file mode 100644 index 000000000..59297ba8f --- /dev/null +++ b/src/sqlancer/spark/ast/SparkBetweenOperation.java @@ -0,0 +1,10 @@ +package sqlancer.spark.ast; + +import sqlancer.common.ast.newast.NewBetweenOperatorNode; + +public class SparkBetweenOperation extends NewBetweenOperatorNode implements SparkExpression { + + public SparkBetweenOperation(SparkExpression left, SparkExpression middle, SparkExpression right, boolean isTrue) { + super(left, middle, right, isTrue); + } +} diff --git a/src/sqlancer/spark/ast/SparkBinaryOperation.java b/src/sqlancer/spark/ast/SparkBinaryOperation.java new file mode 100644 index 000000000..ef3347018 --- /dev/null +++ b/src/sqlancer/spark/ast/SparkBinaryOperation.java @@ -0,0 +1,11 @@ +package sqlancer.spark.ast; + +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.common.ast.newast.NewBinaryOperatorNode; + +public class SparkBinaryOperation extends NewBinaryOperatorNode implements SparkExpression { + + public SparkBinaryOperation(SparkExpression left, SparkExpression right, Operator op) { + super(left, right, op); + } +} diff --git a/src/sqlancer/spark/ast/SparkCaseOperation.java b/src/sqlancer/spark/ast/SparkCaseOperation.java new file mode 100644 index 000000000..995fd7f52 --- /dev/null +++ b/src/sqlancer/spark/ast/SparkCaseOperation.java @@ -0,0 +1,13 @@ +package sqlancer.spark.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewCaseOperatorNode; + +public class SparkCaseOperation extends NewCaseOperatorNode implements SparkExpression { + + public SparkCaseOperation(SparkExpression switchCondition, List conditions, + List expressions, SparkExpression elseExpr) { + super(switchCondition, conditions, expressions, elseExpr); + } +} diff --git a/src/sqlancer/spark/ast/SparkCastOperation.java b/src/sqlancer/spark/ast/SparkCastOperation.java new file mode 100644 index 000000000..547551285 --- /dev/null +++ b/src/sqlancer/spark/ast/SparkCastOperation.java @@ -0,0 +1,25 @@ +package sqlancer.spark.ast; + +import sqlancer.spark.SparkSchema.SparkDataType; + +public class SparkCastOperation implements SparkExpression { + + private final SparkExpression expression; + private final SparkDataType type; + + public SparkCastOperation(SparkExpression expression, SparkDataType type) { + if (expression == null) { + throw new AssertionError(); + } + this.expression = expression; + this.type = type; + } + + public SparkExpression getExpression() { + return expression; + } + + public SparkDataType getType() { + return type; + } +} diff --git a/src/sqlancer/spark/ast/SparkColumnReference.java b/src/sqlancer/spark/ast/SparkColumnReference.java new file mode 100644 index 000000000..ccd1b7855 --- /dev/null +++ b/src/sqlancer/spark/ast/SparkColumnReference.java @@ -0,0 +1,11 @@ +package sqlancer.spark.ast; + +import sqlancer.common.ast.newast.ColumnReferenceNode; +import sqlancer.spark.SparkSchema.SparkColumn; + +public class SparkColumnReference extends ColumnReferenceNode implements SparkExpression { + + public SparkColumnReference(SparkColumn column) { + super(column); + } +} diff --git a/src/sqlancer/spark/ast/SparkConstant.java b/src/sqlancer/spark/ast/SparkConstant.java new file mode 100644 index 000000000..84a397624 --- /dev/null +++ b/src/sqlancer/spark/ast/SparkConstant.java @@ -0,0 +1,194 @@ +package sqlancer.spark.ast; + +import java.math.BigDecimal; +import java.sql.Timestamp; +import java.text.SimpleDateFormat; + +public abstract class SparkConstant implements SparkExpression { + + public boolean isNull() { + return false; + } + + public static class SparkNullConstant extends SparkConstant { + + @Override + public boolean isNull() { + return true; + } + + @Override + public String toString() { + return "NULL"; + } + } + + public static class SparkIntConstant extends SparkConstant { + + private final long value; + + public SparkIntConstant(long value) { + this.value = value; + } + + public long getValue() { + return value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + } + + public static class SparkDoubleConstant extends SparkConstant { + + private final double value; + + public SparkDoubleConstant(double value) { + this.value = value; + } + + public double getValue() { + return value; + } + + @Override + public String toString() { + if (value == Double.POSITIVE_INFINITY) { + return "CAST('Infinity' AS DOUBLE)"; + } else if (value == Double.NEGATIVE_INFINITY) { + return "CAST('-Infinity' AS DOUBLE)"; + } else if (Double.isNaN(value)) { + return "CAST('NaN' AS DOUBLE)"; + } + return String.valueOf(value); + } + } + + public static class SparkDecimalConstant extends SparkConstant { + + private final BigDecimal value; + + public SparkDecimalConstant(BigDecimal value) { + this.value = value; + } + + public BigDecimal getValue() { + return value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + } + + public static class SparkTimestampConstant extends SparkConstant { + + private final String textRepr; + + public SparkTimestampConstant(long value) { + Timestamp timestamp = new Timestamp(value); + SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); // Spark prefers full timestamp + this.textRepr = dateFormat.format(timestamp); + } + + public String getValue() { + return textRepr; + } + + @Override + public String toString() { + return String.format("TIMESTAMP '%s'", textRepr); + } + } + + public static class SparkDateConstant extends SparkConstant { + + private final String textRepr; + + public SparkDateConstant(long value) { + Timestamp timestamp = new Timestamp(value); + SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd"); + this.textRepr = dateFormat.format(timestamp); + } + + public String getValue() { + return textRepr; + } + + @Override + public String toString() { + return String.format("DATE '%s'", textRepr); + } + } + + public static class SparkStringConstant extends SparkConstant { + + private final String value; + + public SparkStringConstant(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + + @Override + public String toString() { + return "'" + value.replace("'", "''").replace("\\", "\\\\") + "'"; + } + } + + public static class SparkBooleanConstant extends SparkConstant { + + private final boolean value; + + public SparkBooleanConstant(boolean value) { + this.value = value; + } + + public boolean getValue() { + return value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + } + + public static SparkConstant createNullConstant() { + return new SparkNullConstant(); + } + + public static SparkConstant createIntConstant(long value) { + return new SparkIntConstant(value); + } + + public static SparkConstant createDoubleConstant(double value) { + return new SparkDoubleConstant(value); + } + + public static SparkConstant createDecimalConstant(BigDecimal value) { + return new SparkDecimalConstant(value); + } + + public static SparkConstant createTimestampConstant(long value) { + return new SparkTimestampConstant(value); + } + + public static SparkConstant createDateConstant(long value) { + return new SparkDateConstant(value); + } + + public static SparkConstant createStringConstant(String value) { + return new SparkStringConstant(value); + } + + public static SparkConstant createBooleanConstant(boolean value) { + return new SparkBooleanConstant(value); + } +} diff --git a/src/sqlancer/spark/ast/SparkExpression.java b/src/sqlancer/spark/ast/SparkExpression.java new file mode 100644 index 000000000..3872ccfda --- /dev/null +++ b/src/sqlancer/spark/ast/SparkExpression.java @@ -0,0 +1,7 @@ +package sqlancer.spark.ast; + +import sqlancer.common.ast.newast.Expression; +import sqlancer.spark.SparkSchema.SparkColumn; + +public interface SparkExpression extends Expression { +} diff --git a/src/sqlancer/spark/ast/SparkFunction.java b/src/sqlancer/spark/ast/SparkFunction.java new file mode 100644 index 000000000..d5740ee36 --- /dev/null +++ b/src/sqlancer/spark/ast/SparkFunction.java @@ -0,0 +1,13 @@ +package sqlancer.spark.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewFunctionNode; + +public class SparkFunction extends NewFunctionNode implements SparkExpression { + + public SparkFunction(List args, F func) { + super(args, func); + } + +} diff --git a/src/sqlancer/spark/ast/SparkInOperation.java b/src/sqlancer/spark/ast/SparkInOperation.java new file mode 100644 index 000000000..430d9b5c2 --- /dev/null +++ b/src/sqlancer/spark/ast/SparkInOperation.java @@ -0,0 +1,12 @@ +package sqlancer.spark.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewInOperatorNode; + +public class SparkInOperation extends NewInOperatorNode implements SparkExpression { + + public SparkInOperation(SparkExpression left, List right, boolean isNegated) { + super(left, right, isNegated); + } +} diff --git a/src/sqlancer/spark/ast/SparkJoin.java b/src/sqlancer/spark/ast/SparkJoin.java new file mode 100644 index 000000000..a59eaff48 --- /dev/null +++ b/src/sqlancer/spark/ast/SparkJoin.java @@ -0,0 +1,46 @@ +package sqlancer.spark.ast; + +import sqlancer.common.ast.newast.Join; +import sqlancer.spark.SparkSchema.SparkColumn; +import sqlancer.spark.SparkSchema.SparkTable; + +public class SparkJoin implements SparkExpression, Join { + + private final SparkTableReference leftTable; + private final SparkTableReference rightTable; + private final JoinType joinType; + private SparkExpression onClause; + + public enum JoinType { + INNER, LEFT_OUTER, RIGHT_OUTER, FULL_OUTER, LEFT_SEMI, LEFT_ANTI, CROSS; + } + + public SparkJoin(SparkTableReference leftTable, SparkTableReference rightTable, JoinType joinType, + SparkExpression onClause) { + this.leftTable = leftTable; + this.rightTable = rightTable; + this.joinType = joinType; + this.onClause = onClause; + } + + public SparkTableReference getLeftTable() { + return leftTable; + } + + public SparkTableReference getRightTable() { + return rightTable; + } + + public JoinType getJoinType() { + return joinType; + } + + public SparkExpression getOnClause() { + return onClause; + } + + @Override + public void setOnClause(SparkExpression onClause) { + this.onClause = onClause; + } +} diff --git a/src/sqlancer/spark/ast/SparkOrderingTerm.java b/src/sqlancer/spark/ast/SparkOrderingTerm.java new file mode 100644 index 000000000..870c8239b --- /dev/null +++ b/src/sqlancer/spark/ast/SparkOrderingTerm.java @@ -0,0 +1,10 @@ +package sqlancer.spark.ast; + +import sqlancer.common.ast.newast.NewOrderingTerm; + +public class SparkOrderingTerm extends NewOrderingTerm implements SparkExpression { + + public SparkOrderingTerm(SparkExpression expr, Ordering ordering) { + super(expr, ordering); + } +} diff --git a/src/sqlancer/spark/ast/SparkSelect.java b/src/sqlancer/spark/ast/SparkSelect.java new file mode 100644 index 000000000..8b59f5513 --- /dev/null +++ b/src/sqlancer/spark/ast/SparkSelect.java @@ -0,0 +1,42 @@ +package sqlancer.spark.ast; + +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.common.ast.SelectBase; +import sqlancer.common.ast.newast.Select; +import sqlancer.spark.SparkSchema.SparkColumn; +import sqlancer.spark.SparkSchema.SparkTable; +import sqlancer.spark.SparkToStringVisitor; + +public class SparkSelect extends SelectBase + implements Select, SparkExpression { + + private boolean isDistinct; + + public void setDistinct(boolean isDistinct) { + this.isDistinct = isDistinct; + } + + public boolean isDistinct() { + return isDistinct; + } + + @Override + public void setJoinClauses(List joinStatements) { + List expressions = joinStatements.stream().map(e -> (SparkExpression) e) + .collect(Collectors.toList()); + setJoinList(expressions); + } + + @Override + public List getJoinClauses() { + return getJoinList().stream().map(e -> (SparkJoin) e).collect(Collectors.toList()); + } + + @Override + public String asString() { + return SparkToStringVisitor.asString(this); + } + +} diff --git a/src/sqlancer/spark/ast/SparkTableReference.java b/src/sqlancer/spark/ast/SparkTableReference.java new file mode 100644 index 000000000..5bcbb5d03 --- /dev/null +++ b/src/sqlancer/spark/ast/SparkTableReference.java @@ -0,0 +1,13 @@ +package sqlancer.spark.ast; + +import sqlancer.common.ast.newast.TableReferenceNode; +import sqlancer.spark.SparkSchema; + +public class SparkTableReference extends TableReferenceNode + implements SparkExpression { + + public SparkTableReference(SparkSchema.SparkTable table) { + super(table); + } + +} diff --git a/src/sqlancer/spark/ast/SparkUnaryPostfixOperation.java b/src/sqlancer/spark/ast/SparkUnaryPostfixOperation.java new file mode 100644 index 000000000..3dd9d28e2 --- /dev/null +++ b/src/sqlancer/spark/ast/SparkUnaryPostfixOperation.java @@ -0,0 +1,13 @@ +package sqlancer.spark.ast; + +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.common.ast.newast.NewUnaryPostfixOperatorNode; + +public class SparkUnaryPostfixOperation extends NewUnaryPostfixOperatorNode + implements SparkExpression { + + public SparkUnaryPostfixOperation(SparkExpression expr, Operator op) { + super(expr, op); + } + +} diff --git a/src/sqlancer/spark/ast/SparkUnaryPrefixOperation.java b/src/sqlancer/spark/ast/SparkUnaryPrefixOperation.java new file mode 100644 index 000000000..5c1a8e4c6 --- /dev/null +++ b/src/sqlancer/spark/ast/SparkUnaryPrefixOperation.java @@ -0,0 +1,12 @@ +package sqlancer.spark.ast; + +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode; + +public class SparkUnaryPrefixOperation extends NewUnaryPrefixOperatorNode implements SparkExpression { + + public SparkUnaryPrefixOperation(SparkExpression expr, Operator op) { + super(expr, op); + } + +} diff --git a/src/sqlancer/spark/gen/SparkExpressionGenerator.java b/src/sqlancer/spark/gen/SparkExpressionGenerator.java new file mode 100644 index 000000000..3708f314a --- /dev/null +++ b/src/sqlancer/spark/gen/SparkExpressionGenerator.java @@ -0,0 +1,336 @@ +package sqlancer.spark.gen; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.common.ast.newast.NewOrderingTerm.Ordering; +import sqlancer.common.gen.TLPWhereGenerator; +import sqlancer.common.gen.UntypedExpressionGenerator; +import sqlancer.common.schema.AbstractTables; +import sqlancer.spark.SparkGlobalState; +import sqlancer.spark.SparkSchema.SparkColumn; +import sqlancer.spark.SparkSchema.SparkDataType; +import sqlancer.spark.SparkSchema.SparkTable; +import sqlancer.spark.ast.SparkBetweenOperation; +import sqlancer.spark.ast.SparkBinaryOperation; +import sqlancer.spark.ast.SparkCaseOperation; +import sqlancer.spark.ast.SparkCastOperation; +import sqlancer.spark.ast.SparkColumnReference; +import sqlancer.spark.ast.SparkConstant; +import sqlancer.spark.ast.SparkExpression; +import sqlancer.spark.ast.SparkFunction; +import sqlancer.spark.ast.SparkInOperation; +import sqlancer.spark.ast.SparkJoin; +import sqlancer.spark.ast.SparkOrderingTerm; +import sqlancer.spark.ast.SparkSelect; +import sqlancer.spark.ast.SparkTableReference; +import sqlancer.spark.ast.SparkUnaryPostfixOperation; +import sqlancer.spark.ast.SparkUnaryPrefixOperation; + +public class SparkExpressionGenerator extends UntypedExpressionGenerator + implements TLPWhereGenerator { + + private final SparkGlobalState globalState; + private List tables; + + private enum Expression { + UNARY_PREFIX, UNARY_POSTFIX, BINARY_COMPARISON, BINARY_LOGICAL, BINARY_ARITHMETIC, CAST, FUNC, BETWEEN, IN, + CASE; + } + + public SparkExpressionGenerator(SparkGlobalState globalState) { + this.globalState = globalState; + } + + @Override + public SparkExpression negatePredicate(SparkExpression predicate) { + return new SparkUnaryPrefixOperation(predicate, SparkUnaryPrefixOperator.NOT); + } + + @Override + public SparkExpression isNull(SparkExpression expr) { + return new SparkUnaryPostfixOperation(expr, SparkUnaryPostfixOperator.IS_NULL); + } + + @Override + protected SparkExpression generateExpression(int depth) { + return generateExpressionInternal(depth); + } + + private SparkExpression generateExpressionInternal(int depth) throws AssertionError { + if (depth >= globalState.getOptions().getMaxExpressionDepth() + || Randomly.getBooleanWithRatherLowProbability()) { + return generateLeafNode(); + } + if (allowAggregates && Randomly.getBooleanWithRatherLowProbability()) { + allowAggregates = false; // aggregate function calls cannot be nested + SparkAggregateFunction aggregate = SparkAggregateFunction.getRandom(); + return new SparkFunction<>(generateExpressions(aggregate.getNrArgs(), depth + 1), aggregate); + } + + List possibleOptions = new ArrayList<>(Arrays.asList(Expression.values())); + Expression expr = Randomly.fromList(possibleOptions); + + switch (expr) { + case UNARY_PREFIX: + return new SparkUnaryPrefixOperation(generateExpression(depth + 1), SparkUnaryPrefixOperator.getRandom()); + case UNARY_POSTFIX: + return new SparkUnaryPostfixOperation(generateExpression(depth + 1), SparkUnaryPostfixOperator.getRandom()); + case BINARY_COMPARISON: + Operator op = SparkBinaryComparisonOperator.getRandom(); + return new SparkBinaryOperation(generateExpression(depth + 1), generateExpression(depth + 1), op); + case BINARY_LOGICAL: + op = SparkBinaryLogicalOperator.getRandom(); + return new SparkBinaryOperation(generateExpression(depth + 1), generateExpression(depth + 1), op); + case BINARY_ARITHMETIC: + return new SparkBinaryOperation(generateExpression(depth + 1), generateExpression(depth + 1), + SparkBinaryArithmeticOperator.getRandom()); + case CAST: + return new SparkCastOperation(generateExpression(depth + 1), SparkDataType.getRandomType()); + case FUNC: + SparkFunc func = SparkFunc.getRandom(); + return new SparkFunction<>(generateExpressions(func.getNrArgs()), func); + case BETWEEN: + return new SparkBetweenOperation(generateExpression(depth + 1), generateExpression(depth + 1), + generateExpression(depth + 1), Randomly.getBoolean()); + case IN: + return new SparkInOperation(generateExpression(depth + 1), + generateExpressions(Randomly.smallNumber() + 1, depth + 1), Randomly.getBoolean()); + case CASE: + int nr = Randomly.smallNumber() + 1; + return new SparkCaseOperation(generateExpression(depth + 1), generateExpressions(nr, depth + 1), + generateExpressions(nr, depth + 1), generateExpression(depth + 1)); + default: + throw new AssertionError(expr); + } + } + + @Override + public SparkExpression generateConstant() { + if (Randomly.getBooleanWithRatherLowProbability()) { + return SparkConstant.createNullConstant(); + } + SparkDataType[] values = SparkDataType.values(); + SparkDataType constantType = Randomly.fromOptions(values); + switch (constantType) { + case STRING: + return SparkConstant.createStringConstant(globalState.getRandomly().getString()); + case INTEGER: + return SparkConstant.createIntConstant(globalState.getRandomly().getInteger()); + case DOUBLE: + return SparkConstant.createDoubleConstant(globalState.getRandomly().getDouble()); + case BOOLEAN: + return SparkConstant.createBooleanConstant(Randomly.getBoolean()); + case TIMESTAMP: + return SparkConstant.createTimestampConstant(globalState.getRandomly().getInteger()); + case DATE: + return SparkConstant.createDateConstant(globalState.getRandomly().getInteger()); + default: + throw new AssertionError(constantType); + } + } + + @Override + protected SparkExpression generateColumn() { + SparkColumn column = Randomly.fromList(columns); + return new SparkColumnReference(column); + } + + @Override + public List generateOrderBys() { + List expr = super.generateOrderBys(); + List newExpr = new ArrayList<>(expr.size()); + for (SparkExpression curExpr : expr) { + if (Randomly.getBoolean()) { + curExpr = new SparkOrderingTerm(curExpr, Ordering.getRandom()); + } + newExpr.add(curExpr); + } + return newExpr; + } + + @Override + public SparkExpressionGenerator setTablesAndColumns(AbstractTables tables) { + this.columns = tables.getColumns(); + this.tables = tables.getTables(); + return this; + } + + @Override + public SparkExpression generateBooleanExpression() { + return generateExpression(); + } + + @Override + public SparkSelect generateSelect() { + return new SparkSelect(); + } + + @Override + public List getTableRefs() { + return tables.stream().map(t -> new SparkTableReference(t)).collect(Collectors.toList()); + } + + @Override + public List generateFetchColumns(boolean allowAggregates) { + if (Randomly.getBoolean()) { + return List.of(new SparkColumnReference(new SparkColumn("*", null, null))); + } + return Randomly.nonEmptySubset(columns).stream().map(c -> new SparkColumnReference(c)) + .collect(Collectors.toList()); + } + + @Override + public List getRandomJoinClauses() { + return List.of(); + } + + public enum SparkUnaryPrefixOperator implements Operator { + NOT("NOT"), PLUS("+"), MINUS("-"), BITWISE_NOT("~"); + + private String textRepr; + + SparkUnaryPrefixOperator(String textRepr) { + this.textRepr = textRepr; + } + + public static SparkUnaryPrefixOperator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepr; + } + } + + public enum SparkUnaryPostfixOperator implements Operator { + IS_NULL("IS NULL"), IS_NOT_NULL("IS NOT NULL"); + + private String textRepr; + + SparkUnaryPostfixOperator(String textRepr) { + this.textRepr = textRepr; + } + + public static SparkUnaryPostfixOperator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepr; + } + } + + public enum SparkBinaryComparisonOperator implements Operator { + EQUALS("="), GREATER(">"), GREATER_EQUALS(">="), SMALLER("<"), SMALLER_EQUALS("<="), NOT_EQUALS("!="), + LIKE("LIKE"), NOT_LIKE("NOT LIKE"), RLIKE("RLIKE"); + + private String textRepr; + + SparkBinaryComparisonOperator(String textRepr) { + this.textRepr = textRepr; + } + + public static SparkBinaryComparisonOperator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepr; + } + } + + public enum SparkBinaryLogicalOperator implements Operator { + AND("AND"), OR("OR"); + + private String textRepr; + + SparkBinaryLogicalOperator(String textRepr) { + this.textRepr = textRepr; + } + + public static SparkBinaryLogicalOperator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepr; + } + } + + public enum SparkBinaryArithmeticOperator implements Operator { + // Spark supports || for concat, and bitwise operators &, |, ^ + CONCAT("||"), ADD("+"), SUB("-"), MULT("*"), DIV("/"), MOD("%"), BITWISE_AND("&"), BITWISE_OR("|"), + BITWISE_XOR("^"); + + private String textRepr; + + SparkBinaryArithmeticOperator(String textRepr) { + this.textRepr = textRepr; + } + + public static SparkBinaryArithmeticOperator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepr; + } + } + + public enum SparkAggregateFunction { + COUNT(1), SUM(1), AVG(1), MIN(1), MAX(1), VARIANCE(1), VAR_SAMP(1), STDDEV_POP(1), STDDEV_SAMP(1), COVAR_POP(2), + COVAR_SAMP(2), CORR(2); + + private int nrArgs; + + SparkAggregateFunction(int nrArgs) { + this.nrArgs = nrArgs; + } + + public static SparkAggregateFunction getRandom() { + return Randomly.fromOptions(values()); + } + + public int getNrArgs() { + return nrArgs; + } + } + + public enum SparkFunc { + ROUND(2), FLOOR(1), ABS(1), CEIL(1); + + private int nrArgs; + private boolean isVariadic; + + SparkFunc(int nrArgs) { + this(nrArgs, false); + } + + SparkFunc(int nrArgs, boolean isVariadic) { + this.nrArgs = nrArgs; + this.isVariadic = isVariadic; + } + + public static SparkFunc getRandom() { + return Randomly.fromOptions(values()); + } + + public int getNrArgs() { + if (isVariadic) { + return Randomly.smallNumber() + nrArgs; + } else { + return nrArgs; + } + } + } +} diff --git a/src/sqlancer/spark/gen/SparkInsertGenerator.java b/src/sqlancer/spark/gen/SparkInsertGenerator.java new file mode 100644 index 000000000..b1755a848 --- /dev/null +++ b/src/sqlancer/spark/gen/SparkInsertGenerator.java @@ -0,0 +1,47 @@ +package sqlancer.spark.gen; + +import java.util.List; + +import sqlancer.common.gen.AbstractInsertGenerator; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.spark.SparkErrors; +import sqlancer.spark.SparkGlobalState; +import sqlancer.spark.SparkSchema.SparkColumn; +import sqlancer.spark.SparkSchema.SparkTable; +import sqlancer.spark.SparkToStringVisitor; + +public class SparkInsertGenerator extends AbstractInsertGenerator { + + private final SparkGlobalState globalState; + private final ExpectedErrors errors = new ExpectedErrors(); + private final SparkExpressionGenerator gen; + + public SparkInsertGenerator(SparkGlobalState globalState) { + this.globalState = globalState; + this.gen = new SparkExpressionGenerator(globalState); + } + + public static SQLQueryAdapter getQuery(SparkGlobalState globalState) { + return new SparkInsertGenerator(globalState).generate(); + } + + @Override + protected void insertValue(SparkColumn column) { + sb.append(SparkToStringVisitor.asString(gen.generateConstant())); + } + + private SQLQueryAdapter generate() { + sb.append("INSERT INTO "); + SparkTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + sb.append(table.getName()); + + sb.append(" VALUES "); + + List columns = table.getColumns(); + insertColumns(columns); + + SparkErrors.addInsertErrors(errors); + return new SQLQueryAdapter(sb.toString(), errors, false, false); + } +} diff --git a/src/sqlancer/spark/gen/SparkTableGenerator.java b/src/sqlancer/spark/gen/SparkTableGenerator.java new file mode 100644 index 000000000..937e52248 --- /dev/null +++ b/src/sqlancer/spark/gen/SparkTableGenerator.java @@ -0,0 +1,102 @@ +package sqlancer.spark.gen; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.DBMSCommon; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.spark.SparkErrors; +import sqlancer.spark.SparkGlobalState; +import sqlancer.spark.SparkSchema; +import sqlancer.spark.SparkSchema.SparkColumn; +import sqlancer.spark.SparkSchema.SparkDataType; +import sqlancer.spark.SparkSchema.SparkTable; +import sqlancer.spark.SparkToStringVisitor; + +public class SparkTableGenerator { + + private enum ColumnConstraints { + NOT_NULL, DEFAULT + // PRIMARY KEY and UNIQUE are often not supported in standard Spark file sources + // (Parquet/ORC) + // without specific catalogs (like Delta/Iceberg), so we limit to constraints + // Spark SQL widely accepts. + } + + private final SparkGlobalState globalState; + private final String tableName; + private final StringBuilder sb = new StringBuilder(); + private final SparkExpressionGenerator gen; + private final SparkTable table; + private final List columnsToBeAdded = new ArrayList<>(); + + public SparkTableGenerator(SparkGlobalState globalState, String tableName) { + this.tableName = tableName; + this.globalState = globalState; + this.table = new SparkTable(tableName, columnsToBeAdded, false); + this.gen = new SparkExpressionGenerator(globalState).setColumns(columnsToBeAdded); + } + + public static SQLQueryAdapter generate(SparkGlobalState globalState, String tableName) { + SparkTableGenerator generator = new SparkTableGenerator(globalState, tableName); + return generator.create(); + } + + private SQLQueryAdapter create() { + ExpectedErrors errors = new ExpectedErrors(); + + sb.append("CREATE TABLE "); + sb.append(globalState.getDatabaseName()); + sb.append("."); + sb.append(tableName); + sb.append(" ("); + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + if (i != 0) { + sb.append(", "); + } + appendColumn(i); + } + sb.append(")"); + sb.append(" USING PARQUET"); + + // TODO: implement PARTITION BY clause + // TODO: implement CLUSTERED BY clauses + // TODO: implement ROW FORMAT and STORED AS clauses + // TODO: randomly add some predefined TABLEPROPERTIES + + SparkErrors.addExpressionErrors(errors); + return new SQLQueryAdapter(sb.toString(), errors, true, false); + } + + private void appendColumn(int columnId) { + String columnName = DBMSCommon.createColumnName(columnId); + sb.append(columnName); + sb.append(" "); + SparkDataType randType = SparkSchema.SparkDataType.getRandomType(); + sb.append(randType); + columnsToBeAdded.add(new SparkColumn(columnName, table, randType)); + appendColumnConstraint(); + } + + private void appendColumnConstraint() { + if (Randomly.getBoolean()) { + return; + } + + ColumnConstraints constraint = Randomly.fromOptions(ColumnConstraints.values()); + switch (constraint) { + case NOT_NULL: + sb.append(" NOT NULL"); + break; + case DEFAULT: + sb.append(" DEFAULT "); + sb.append(SparkToStringVisitor.asString(gen.generateConstant())); + sb.append(" "); + break; + default: + throw new AssertionError(constraint); + } + } +} diff --git a/src/sqlancer/sqlite3/SQLite3Errors.java b/src/sqlancer/sqlite3/SQLite3Errors.java index b872fcd1d..96e45c549 100644 --- a/src/sqlancer/sqlite3/SQLite3Errors.java +++ b/src/sqlancer/sqlite3/SQLite3Errors.java @@ -1,6 +1,8 @@ package sqlancer.sqlite3; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import sqlancer.common.query.ExpectedErrors; @@ -9,19 +11,28 @@ public final class SQLite3Errors { private SQLite3Errors() { } - public static void addDeleteErrors(ExpectedErrors errors) { + public static List getDeleteErrors() { + ArrayList errors = new ArrayList<>(); + // DELETE trigger for a view/table to which colomns were added or deleted errors.add("columns but"); // trigger with on conflict clause errors.add("ON CONFLICT clause does not match any PRIMARY KEY or UNIQUE constraint"); + + return errors; } - public static void addExpectedExpressionErrors(ExpectedErrors errors) { + public static void addDeleteErrors(ExpectedErrors errors) { + errors.addAll(getDeleteErrors()); + } + + public static List getExpectedExpressionErrors() { + ArrayList errors = new ArrayList<>(); + errors.add("[SQLITE_BUSY] The database file is locked"); errors.add("FTS expression tree is too large"); errors.add("String or BLOB exceeds size limit"); errors.add("[SQLITE_ERROR] SQL error or missing database (integer overflow)"); - errors.add("second argument to likelihood() must be a constant between 0.0 and 1.0"); errors.add("ORDER BY term out of range"); errors.add("GROUP BY term out of range"); errors.add("not authorized"); // load_extension @@ -45,9 +56,11 @@ public static void addExpectedExpressionErrors(ExpectedErrors errors) { errors.add("malformed JSON"); errors.add("JSON cannot hold BLOB values"); errors.add("JSON path error"); + errors.add("bad JSON path"); errors.add("json_insert() needs an odd number of arguments"); errors.add("json_object() labels must be TEXT"); errors.add("json_object() requires an even number of arguments"); + errors.add("argument of ntile must be a positive integer"); // fts5 functions errors.add("unable to use function highlight in the requested context"); @@ -62,9 +75,21 @@ public static void addExpectedExpressionErrors(ExpectedErrors errors) { errors.add("ORDER BY clause should come after"); errors.add("LIMIT clause should come after"); + errors.add("unsafe use of load_extension"); + errors.add("table does not support scanning"); + errors.add("circularly defined"); + errors.add("[SQLITE_ERROR] SQL error or missing database"); // A possible delay in the execution of DROP TABLE + // statement. + return errors; } - public static void addMatchQueryErrors(ExpectedErrors errors) { + public static void addExpectedExpressionErrors(ExpectedErrors errors) { + errors.addAll(getExpectedExpressionErrors()); + } + + public static List getMatchQueryErrors() { + ArrayList errors = new ArrayList<>(); + errors.add("unable to use function MATCH in the requested context"); errors.add("malformed MATCH expression"); errors.add("fts5: syntax error near"); @@ -73,28 +98,60 @@ public static void addMatchQueryErrors(ExpectedErrors errors) { errors.add("fts5: column queries are not supported"); // vt0.c0 MATCH '2016456922' errors.add("fts5: phrase queries are not supported"); errors.add("unterminated string"); + + return errors; } - public static void addTableManipulationErrors(ExpectedErrors errors) { + public static void addMatchQueryErrors(ExpectedErrors errors) { + errors.addAll(getMatchQueryErrors()); + } + + public static List getTableManipulationErrors() { + ArrayList errors = new ArrayList<>(); + errors.add("unsupported frame specification"); errors.add("non-deterministic functions prohibited in CHECK constraints"); errors.addAll(Arrays.asList("subqueries prohibited in CHECK constraints", "generated columns cannot be part of the PRIMARY KEY", "must have at least one non-generated column")); + + return errors; } - public static void addQueryErrors(ExpectedErrors errors) { + public static void addTableManipulationErrors(ExpectedErrors errors) { + errors.addAll(getTableManipulationErrors()); + } + + public static List getQueryErrors() { + ArrayList errors = new ArrayList<>(); + errors.add("ON clause references tables to its right"); + + return errors; } - public static void addInsertNowErrors(ExpectedErrors errors) { + public static void addQueryErrors(ExpectedErrors errors) { + errors.addAll(getQueryErrors()); + } + + public static List getInsertNowErrors() { + ArrayList errors = new ArrayList<>(); + errors.add("non-deterministic use of strftime()"); errors.add("non-deterministic use of time()"); errors.add("non-deterministic use of datetime()"); errors.add("non-deterministic use of julianday()"); errors.add("non-deterministic use of date()"); + + return errors; } - public static void addInsertUpdateErrors(ExpectedErrors errors) { + public static void addInsertNowErrors(ExpectedErrors errors) { + errors.addAll(getInsertNowErrors()); + } + + public static List getInsertUpdateErrors() { + ArrayList errors = new ArrayList<>(); + errors.add("String or BLOB exceeds size limit"); errors.add("[SQLITE_CONSTRAINT_CHECK]"); errors.add("[SQLITE_CONSTRAINT_PRIMARYKEY]"); @@ -109,6 +166,12 @@ public static void addInsertUpdateErrors(ExpectedErrors errors) { errors.add("[SQLITE_ERROR] SQL error or missing database (no such table:"); errors.add("[SQLITE_ERROR] SQL error or missing database (foreign key mismatch"); errors.add("no such column"); // trigger + + return errors; + } + + public static void addInsertUpdateErrors(ExpectedErrors errors) { + errors.addAll(getInsertUpdateErrors()); } } diff --git a/src/sqlancer/sqlite3/SQLite3ExpectedValueVisitor.java b/src/sqlancer/sqlite3/SQLite3ExpectedValueVisitor.java index 1c148ee86..7ca82ec02 100644 --- a/src/sqlancer/sqlite3/SQLite3ExpectedValueVisitor.java +++ b/src/sqlancer/sqlite3/SQLite3ExpectedValueVisitor.java @@ -14,14 +14,21 @@ import sqlancer.sqlite3.ast.SQLite3Expression.InOperation; import sqlancer.sqlite3.ast.SQLite3Expression.Join; import sqlancer.sqlite3.ast.SQLite3Expression.MatchOperation; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3Alias; import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3ColumnName; import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3Distinct; import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3Exist; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3ExpressionBag; import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3OrderingTerm; import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3PostfixText; import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3PostfixUnaryOperation; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3ResultMap; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3TableAndColumnRef; import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3TableReference; import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3Text; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3Typeof; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3Values; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3WithClause; import sqlancer.sqlite3.ast.SQLite3Expression.Sqlite3BinaryOperation; import sqlancer.sqlite3.ast.SQLite3Expression.Subquery; import sqlancer.sqlite3.ast.SQLite3Expression.TypeLiteral; @@ -303,4 +310,43 @@ public void visit(SQLite3SetClause set) { visit(set.getRight()); } + @Override + public void visit(SQLite3Alias alias) { + print(alias); + print(alias.getOriginalExpression()); + print(alias.getAliasExpression()); + } + + @Override + public void visit(SQLite3WithClause withClause) { + print(withClause); + print(withClause.getLeft()); + print(withClause.getRight()); + } + + @Override + public void visit(SQLite3TableAndColumnRef tableAndColumnRef) { + print(tableAndColumnRef); + } + + @Override + public void visit(SQLite3Values values) { + print(values); + } + + @Override + public void visit(SQLite3ExpressionBag expr) { + print(expr); + print(expr.getInnerExpr()); + } + + @Override + public void visit(SQLite3Typeof expr) { + print(expr); + print(expr.getInnerExpr()); + } + + @Override + public void visit(SQLite3ResultMap tableSummary) { + } } diff --git a/src/sqlancer/sqlite3/SQLite3Options.java b/src/sqlancer/sqlite3/SQLite3Options.java index dabc3ce13..e9e34892e 100644 --- a/src/sqlancer/sqlite3/SQLite3Options.java +++ b/src/sqlancer/sqlite3/SQLite3Options.java @@ -1,7 +1,5 @@ package sqlancer.sqlite3; -import java.sql.SQLException; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -9,18 +7,6 @@ import com.beust.jcommander.Parameters; import sqlancer.DBMSSpecificOptions; -import sqlancer.OracleFactory; -import sqlancer.common.oracle.CompositeTestOracle; -import sqlancer.common.oracle.TestOracle; -import sqlancer.sqlite3.SQLite3Options.SQLite3OracleFactory; -import sqlancer.sqlite3.oracle.SQLite3Fuzzer; -import sqlancer.sqlite3.oracle.SQLite3NoRECOracle; -import sqlancer.sqlite3.oracle.SQLite3PivotedQuerySynthesisOracle; -import sqlancer.sqlite3.oracle.tlp.SQLite3TLPAggregateOracle; -import sqlancer.sqlite3.oracle.tlp.SQLite3TLPDistinctOracle; -import sqlancer.sqlite3.oracle.tlp.SQLite3TLPGroupByOracle; -import sqlancer.sqlite3.oracle.tlp.SQLite3TLPHavingOracle; -import sqlancer.sqlite3.oracle.tlp.SQLite3TLPWhereOracle; @Parameters(separators = "=", commandDescription = "SQLite3") public class SQLite3Options implements DBMSSpecificOptions { @@ -87,83 +73,31 @@ public class SQLite3Options implements DBMSSpecificOptions public boolean generateDatabase = true; @Parameter(names = { - "--execute-queries" }, description = "Specifies whether the query in the fuzzer should be executed", arity = 1) - public boolean executeQuery = true; - - public enum SQLite3OracleFactory implements OracleFactory { - PQS { - @Override - public TestOracle create(SQLite3GlobalState globalState) throws SQLException { - return new SQLite3PivotedQuerySynthesisOracle(globalState); - } - - @Override - public boolean requiresAllTablesToContainRows() { - return true; - } - - }, - NoREC { - @Override - public TestOracle create(SQLite3GlobalState globalState) throws SQLException { - return new SQLite3NoRECOracle(globalState); - } - }, - AGGREGATE { - - @Override - public TestOracle create(SQLite3GlobalState globalState) throws SQLException { - return new SQLite3TLPAggregateOracle(globalState); - } - - }, - WHERE { - - @Override - public TestOracle create(SQLite3GlobalState globalState) throws SQLException { - return new SQLite3TLPWhereOracle(globalState); - } - - }, - DISTINCT { - @Override - public TestOracle create(SQLite3GlobalState globalState) throws SQLException { - return new SQLite3TLPDistinctOracle(globalState); - } - }, - GROUP_BY { - @Override - public TestOracle create(SQLite3GlobalState globalState) throws SQLException { - return new SQLite3TLPGroupByOracle(globalState); - } - }, - HAVING { - @Override - public TestOracle create(SQLite3GlobalState globalState) throws SQLException { - return new SQLite3TLPHavingOracle(globalState); - } - }, - FUZZER { - @Override - public TestOracle create(SQLite3GlobalState globalState) throws SQLException { - return new SQLite3Fuzzer(globalState); - } - }, - QUERY_PARTITIONING { - @Override - public TestOracle create(SQLite3GlobalState globalState) throws SQLException { - List oracles = new ArrayList<>(); - oracles.add(new SQLite3TLPWhereOracle(globalState)); - oracles.add(new SQLite3TLPDistinctOracle(globalState)); - oracles.add(new SQLite3TLPGroupByOracle(globalState)); - oracles.add(new SQLite3TLPHavingOracle(globalState)); - oracles.add(new SQLite3TLPAggregateOracle(globalState)); - return new CompositeTestOracle(oracles, globalState); - } - }; + "--max-num-tables" }, description = "The maximum number of tables/virtual tables/ rtree tables/ views that can be created") + public int maxNumTables = 10; + @Parameter(names = { "--max-num-indexes" }, description = "The maximum number of indexes that can be created") + public int maxNumIndexes = 20; + + public enum CODDTestModel { + RANDOM, EXPRESSION, SUBQUERY; + + public boolean isRandom() { + return this == RANDOM; + } + + public boolean isExpression() { + return this == EXPRESSION; + } + + public boolean isSubquery() { + return this == SUBQUERY; + } } + @Parameter(names = { "--coddtest-model" }, description = "Apply CODDTest on EXPRESSION, SUBQUERY, or RANDOM") + public CODDTestModel coddTestModel = CODDTestModel.RANDOM; + @Override public List getTestOracleFactory() { return Arrays.asList(oracles); diff --git a/src/sqlancer/sqlite3/SQLite3OracleFactory.java b/src/sqlancer/sqlite3/SQLite3OracleFactory.java new file mode 100644 index 000000000..2e2f9f0f7 --- /dev/null +++ b/src/sqlancer/sqlite3/SQLite3OracleFactory.java @@ -0,0 +1,113 @@ +package sqlancer.sqlite3; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import sqlancer.OracleFactory; +import sqlancer.common.oracle.CompositeTestOracle; +import sqlancer.common.oracle.NoRECOracle; +import sqlancer.common.oracle.TLPWhereOracle; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.sqlite3.gen.SQLite3ExpressionGenerator; +import sqlancer.sqlite3.oracle.SQLite3CODDTestOracle; +import sqlancer.sqlite3.oracle.SQLite3Fuzzer; +import sqlancer.sqlite3.oracle.SQLite3PivotedQuerySynthesisOracle; +import sqlancer.sqlite3.oracle.tlp.SQLite3TLPAggregateOracle; +import sqlancer.sqlite3.oracle.tlp.SQLite3TLPDistinctOracle; +import sqlancer.sqlite3.oracle.tlp.SQLite3TLPGroupByOracle; +import sqlancer.sqlite3.oracle.tlp.SQLite3TLPHavingOracle; + +public enum SQLite3OracleFactory implements OracleFactory { + PQS { + @Override + public TestOracle create(SQLite3GlobalState globalState) throws SQLException { + return new SQLite3PivotedQuerySynthesisOracle(globalState); + } + + @Override + public boolean requiresAllTablesToContainRows() { + return true; + } + + }, + NoREC { + @Override + public TestOracle create(SQLite3GlobalState globalState) throws SQLException { + SQLite3ExpressionGenerator gen = new SQLite3ExpressionGenerator(globalState); + ExpectedErrors errors = ExpectedErrors.newErrors().with(SQLite3Errors.getExpectedExpressionErrors()) + .with(SQLite3Errors.getMatchQueryErrors()).with(SQLite3Errors.getQueryErrors()) + .with("misuse of aggregate", "misuse of window function", + "second argument to nth_value must be a positive integer", "no such table", + "no query solution", "unable to use function MATCH in the requested context") + .build(); + return new NoRECOracle<>(globalState, gen, errors); + } + }, + AGGREGATE { + @Override + public TestOracle create(SQLite3GlobalState globalState) throws SQLException { + return new SQLite3TLPAggregateOracle(globalState); + } + + }, + WHERE { + @Override + public TestOracle create(SQLite3GlobalState globalState) throws SQLException { + SQLite3ExpressionGenerator gen = new SQLite3ExpressionGenerator(globalState); + ExpectedErrors expectedErrors = ExpectedErrors.newErrors().with(SQLite3Errors.getExpectedExpressionErrors()) + .build(); + return new TLPWhereOracle<>(globalState, gen, expectedErrors); + } + + }, + DISTINCT { + @Override + public TestOracle create(SQLite3GlobalState globalState) throws SQLException { + return new SQLite3TLPDistinctOracle(globalState); + } + }, + GROUP_BY { + @Override + public TestOracle create(SQLite3GlobalState globalState) throws SQLException { + return new SQLite3TLPGroupByOracle(globalState); + } + }, + HAVING { + @Override + public TestOracle create(SQLite3GlobalState globalState) throws SQLException { + return new SQLite3TLPHavingOracle(globalState); + } + }, + FUZZER { + @Override + public TestOracle create(SQLite3GlobalState globalState) throws SQLException { + return new SQLite3Fuzzer(globalState); + } + }, + QUERY_PARTITIONING { + @Override + public TestOracle create(SQLite3GlobalState globalState) throws Exception { + List> oracles = new ArrayList<>(); + oracles.add(WHERE.create(globalState)); + oracles.add(DISTINCT.create(globalState)); + oracles.add(GROUP_BY.create(globalState)); + oracles.add(HAVING.create(globalState)); + oracles.add(AGGREGATE.create(globalState)); + return new CompositeTestOracle(oracles, globalState); + } + }, + CODDTest { + @Override + public TestOracle create(SQLite3GlobalState globalState) throws SQLException { + return new SQLite3CODDTestOracle(globalState); + } + + @Override + public boolean requiresAllTablesToContainRows() { + return true; + } + }; + +} diff --git a/src/sqlancer/sqlite3/SQLite3Provider.java b/src/sqlancer/sqlite3/SQLite3Provider.java index eb88a0e0a..5fbd3b471 100644 --- a/src/sqlancer/sqlite3/SQLite3Provider.java +++ b/src/sqlancer/sqlite3/SQLite3Provider.java @@ -1,11 +1,13 @@ package sqlancer.sqlite3; import java.io.File; +import java.io.IOException; import java.sql.DriverManager; import java.sql.SQLException; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.stream.Collectors; import com.google.auto.service.AutoService; @@ -20,7 +22,7 @@ import sqlancer.common.query.ExpectedErrors; import sqlancer.common.query.SQLQueryAdapter; import sqlancer.common.query.SQLQueryProvider; -import sqlancer.sqlite3.SQLite3Options.SQLite3OracleFactory; +import sqlancer.common.query.SQLancerResultSet; import sqlancer.sqlite3.gen.SQLite3AnalyzeGenerator; import sqlancer.sqlite3.gen.SQLite3CreateVirtualRtreeTabelGenerator; import sqlancer.sqlite3.gen.SQLite3ExplainGenerator; @@ -58,47 +60,50 @@ public SQLite3Provider() { } public enum Action implements AbstractAction { - PRAGMA(SQLite3PragmaGenerator::insertPragma), // - INDEX(SQLite3IndexGenerator::insertIndex), // - INSERT(SQLite3InsertGenerator::insertRow), // - VACUUM(SQLite3VacuumGenerator::executeVacuum), // - REINDEX(SQLite3ReindexGenerator::executeReindex), // - ANALYZE(SQLite3AnalyzeGenerator::generateAnalyze), // - DELETE(SQLite3DeleteGenerator::deleteContent), // + PRAGMA(SQLite3PragmaGenerator::insertPragma), // 0 + CREATE_INDEX(SQLite3IndexGenerator::insertIndex), // 1 + CREATE_VIEW(SQLite3ViewGenerator::generate), // 2 + CREATE_TRIGGER(SQLite3CreateTriggerGenerator::create), // 3 + CREATE_TABLE(SQLite3TableGenerator::createRandomTableStatement), // 4 + CREATE_VIRTUALTABLE(SQLite3CreateVirtualFTSTableGenerator::createRandomTableStatement), // 5 + CREATE_RTREETABLE(SQLite3CreateVirtualRtreeTabelGenerator::createRandomTableStatement), // 6 + INSERT(SQLite3InsertGenerator::insertRow), // 7 + DELETE(SQLite3DeleteGenerator::deleteContent), // 8 + ALTER(SQLite3AlterTable::alterTable), // 9 + UPDATE(SQLite3UpdateGenerator::updateRow), // 10 + DROP_INDEX(SQLite3DropIndexGenerator::dropIndex), // 11 + DROP_TABLE(SQLite3DropTableGenerator::dropTable), // 12 + DROP_VIEW(SQLite3ViewGenerator::dropView), // 13 + VACUUM(SQLite3VacuumGenerator::executeVacuum), // 14 + REINDEX(SQLite3ReindexGenerator::executeReindex), // 15 + ANALYZE(SQLite3AnalyzeGenerator::generateAnalyze), // 16 + EXPLAIN(SQLite3ExplainGenerator::explain), // 17 + CHECK_RTREE_TABLE((g) -> { + SQLite3Table table = g.getSchema().getRandomTableOrBailout(t -> t.getName().startsWith("r")); + String format = String.format("SELECT rtreecheck('%s');", table.getName()); + return new SQLQueryAdapter(format, ExpectedErrors.from("The database file is locked")); + }), // 18 + VIRTUAL_TABLE_ACTION(SQLite3VirtualFTSTableCommandGenerator::create), // 19 + MANIPULATE_STAT_TABLE(SQLite3StatTableGenerator::getQuery), // 20 TRANSACTION_START(SQLite3TransactionGenerator::generateBeginTransaction) { @Override public boolean canBeRetried() { return false; } - }, // - ALTER(SQLite3AlterTable::alterTable), // - DROP_INDEX(SQLite3DropIndexGenerator::dropIndex), // - UPDATE(SQLite3UpdateGenerator::updateRow), // + }, // 21 ROLLBACK_TRANSACTION(SQLite3TransactionGenerator::generateRollbackTransaction) { @Override public boolean canBeRetried() { return false; } - }, // + }, // 22 COMMIT(SQLite3TransactionGenerator::generateCommit) { @Override public boolean canBeRetried() { return false; } - }, // - DROP_TABLE(SQLite3DropTableGenerator::dropTable), // - DROP_VIEW(SQLite3ViewGenerator::dropView), // - EXPLAIN(SQLite3ExplainGenerator::explain), // - CHECK_RTREE_TABLE((g) -> { - SQLite3Table table = g.getSchema().getRandomTableOrBailout(t -> t.getName().startsWith("r")); - String format = String.format("SELECT rtreecheck('%s');", table.getName()); - return new SQLQueryAdapter(format, ExpectedErrors.from("The database file is locked")); - }), // - VIRTUAL_TABLE_ACTION(SQLite3VirtualFTSTableCommandGenerator::create), // - CREATE_VIEW(SQLite3ViewGenerator::generate), // - CREATE_TRIGGER(SQLite3CreateTriggerGenerator::create), // - MANIPULATE_STAT_TABLE(SQLite3StatTableGenerator::getQuery); + }; // 23 private final SQLQueryProvider sqlQueryProvider; @@ -146,7 +151,7 @@ private static int mapActions(SQLite3GlobalState globalState, Action a) { case MANIPULATE_STAT_TABLE: nrPerformed = r.getInteger(0, 5); break; - case INDEX: + case CREATE_INDEX: nrPerformed = r.getInteger(0, 5); break; case VIRTUAL_TABLE_ACTION: @@ -156,6 +161,11 @@ private static int mapActions(SQLite3GlobalState globalState, Action a) { case PRAGMA: nrPerformed = r.getInteger(0, 20); break; + case CREATE_TABLE: + case CREATE_VIRTUALTABLE: + case CREATE_RTREETABLE: + nrPerformed = 0; + break; case TRANSACTION_START: case REINDEX: case ANALYZE: @@ -291,4 +301,59 @@ public SQLConnection createDatabase(SQLite3GlobalState globalState) throws SQLEx public String getDBMSName() { return "sqlite3"; } + + @Override + public String getQueryPlan(String selectStr, SQLite3GlobalState globalState) throws Exception { + String queryPlan = ""; + if (globalState.getOptions().logEachSelect()) { + globalState.getLogger().writeCurrent(selectStr); + try { + globalState.getLogger().getCurrentFileWriter().flush(); + } catch (IOException e) { + e.printStackTrace(); + } + } + // Set up the expected errors for NoREC oracle. + ExpectedErrors errors = new ExpectedErrors(); + SQLite3Errors.addExpectedExpressionErrors(errors); + SQLite3Errors.addMatchQueryErrors(errors); + SQLite3Errors.addQueryErrors(errors); + SQLite3Errors.addInsertUpdateErrors(errors); + + SQLQueryAdapter q = new SQLQueryAdapter(SQLite3ExplainGenerator.explain(selectStr), errors); + try (SQLancerResultSet rs = q.executeAndGet(globalState)) { + if (rs != null) { + while (rs.next()) { + queryPlan += rs.getString(4) + ";"; + } + } + } catch (SQLException | AssertionError e) { + queryPlan = ""; + } + return queryPlan; + } + + @Override + protected double[] initializeWeightedAverageReward() { + return new double[Action.values().length]; + } + + @Override + protected void executeMutator(int index, SQLite3GlobalState globalState) throws Exception { + SQLQueryAdapter queryMutateTable = Action.values()[index].getQuery(globalState); + globalState.executeStatement(queryMutateTable); + + } + + @Override + protected boolean addRowsToAllTables(SQLite3GlobalState globalState) throws Exception { + List tablesNoRow = globalState.getSchema().getDatabaseTables().stream() + .filter(t -> t.getNrRows(globalState) == 0).collect(Collectors.toList()); + for (SQLite3Table table : tablesNoRow) { + SQLQueryAdapter queryAddRows = SQLite3InsertGenerator.insertRow(globalState, table); + globalState.executeStatement(queryAddRows); + } + + return true; + } } diff --git a/src/sqlancer/sqlite3/SQLite3ToStringVisitor.java b/src/sqlancer/sqlite3/SQLite3ToStringVisitor.java index f81259732..8fdbf4438 100644 --- a/src/sqlancer/sqlite3/SQLite3ToStringVisitor.java +++ b/src/sqlancer/sqlite3/SQLite3ToStringVisitor.java @@ -1,6 +1,9 @@ package sqlancer.sqlite3; import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; import sqlancer.Randomly; import sqlancer.common.visitor.ToStringVisitor; @@ -11,6 +14,7 @@ import sqlancer.sqlite3.ast.SQLite3Case.SQLite3CaseWithoutBaseExpression; import sqlancer.sqlite3.ast.SQLite3Cast; import sqlancer.sqlite3.ast.SQLite3Constant; +import sqlancer.sqlite3.ast.SQLite3Constant.SQLite3NullConstant; import sqlancer.sqlite3.ast.SQLite3Expression; import sqlancer.sqlite3.ast.SQLite3Expression.BetweenOperation; import sqlancer.sqlite3.ast.SQLite3Expression.Cast; @@ -19,12 +23,19 @@ import sqlancer.sqlite3.ast.SQLite3Expression.InOperation; import sqlancer.sqlite3.ast.SQLite3Expression.Join; import sqlancer.sqlite3.ast.SQLite3Expression.MatchOperation; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3Alias; import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3ColumnName; import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3Distinct; import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3Exist; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3ExpressionBag; import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3OrderingTerm; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3ResultMap; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3TableAndColumnRef; import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3TableReference; import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3Text; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3Typeof; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3Values; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3WithClause; import sqlancer.sqlite3.ast.SQLite3Expression.Subquery; import sqlancer.sqlite3.ast.SQLite3Expression.TypeLiteral; import sqlancer.sqlite3.ast.SQLite3Function; @@ -35,6 +46,7 @@ import sqlancer.sqlite3.ast.SQLite3WindowFunctionExpression; import sqlancer.sqlite3.ast.SQLite3WindowFunctionExpression.SQLite3WindowFunctionFrameSpecBetween; import sqlancer.sqlite3.ast.SQLite3WindowFunctionExpression.SQLite3WindowFunctionFrameSpecTerm; +import sqlancer.sqlite3.schema.SQLite3DataType; public class SQLite3ToStringVisitor extends ToStringVisitor implements SQLite3Visitor { @@ -99,6 +111,10 @@ public void visit(SQLite3Select s, boolean inner) { if (inner) { sb.append("("); } + if (s.getWithClause() != null) { + visit(s.getWithClause()); + sb.append(" "); + } sb.append("SELECT "); switch (s.getFromOptions()) { case DISTINCT: @@ -139,7 +155,7 @@ public void visit(SQLite3Select s, boolean inner) { visit(whereClause); sb.append(")"); } - if (s.getGroupByClause().size() > 0) { + if (!s.getGroupByClause().isEmpty()) { sb.append(" "); sb.append("GROUP BY "); visit(s.getGroupByClause()); @@ -148,9 +164,9 @@ public void visit(SQLite3Select s, boolean inner) { sb.append(" HAVING "); visit(s.getHavingClause()); } - if (!s.getOrderByClause().isEmpty()) { + if (!s.getOrderByClauses().isEmpty()) { sb.append(" ORDER BY "); - visit(s.getOrderByClause()); + visit(s.getOrderByClauses()); } if (s.getLimitClause() != null) { sb.append(" LIMIT "); @@ -237,6 +253,12 @@ public void visit(Join join) { case OUTER: sb.append("LEFT OUTER"); break; + case RIGHT: + sb.append("RIGHT OUTER"); + break; + case FULL: + sb.append("FULL OUTER"); + break; default: throw new AssertionError(join.getType()); } @@ -282,13 +304,20 @@ public void visit(InOperation op) { sb.append("("); visit(op.getLeft()); sb.append(" IN "); - sb.append("("); if (op.getRightExpressionList() != null) { + sb.append("("); visit(op.getRightExpressionList()); + sb.append(")"); } else { - visit(op.getRightSelect()); + if (op.getRightSelect() instanceof SQLite3Expression.SQLite3TableReference) { + visit(op.getRightSelect()); + } else { + sb.append("("); + visit(op.getRightSelect()); + sb.append(")"); + } } - sb.append(")"); + sb.append(")"); } @@ -299,6 +328,9 @@ public void visit(Subquery query) { @Override public void visit(SQLite3Exist exist) { + if (exist.getNegated()) { + sb.append(" NOT"); + } sb.append(" EXISTS "); if (exist.getExpression() instanceof SQLite3SetClause) { sb.append("("); @@ -476,4 +508,104 @@ public void visit(SQLite3SetClause set) { sb.append(SQLite3Visitor.asString(set.getRight())); } + @Override + public void visit(SQLite3Alias alias) { + sb.append("("); + visit(alias.getOriginalExpression()); + sb.append(")"); + sb.append(" AS "); + visit(alias.getAliasExpression()); + } + + @Override + public void visit(SQLite3WithClause withClause) { + sb.append("WITH "); + visit(withClause.getLeft()); + sb.append(" AS "); + visit(withClause.getRight()); + } + + @Override + public void visit(SQLite3TableAndColumnRef tableAndColumnRef) { + sb.append(tableAndColumnRef.getString()); + } + + @Override + public void visit(SQLite3Values values) { + Map> vs = values.getValues(); + int size = vs.get(vs.keySet().iterator().next()).size(); + List columnNames = values.getColumns().stream().map(c -> c.getName()).collect(Collectors.toList()); + sb.append("(VALUES "); + for (int i = 0; i < size; i++) { + sb.append("("); + Boolean isFirstColumn = true; + for (String name : columnNames) { + if (!isFirstColumn) { + sb.append(", "); + } + if (vs.get(name).get(i).getDataType() == SQLite3DataType.NULL) { + visit(vs.get(name).get(i)); + } else { + sb.append("(CAST("); + visit(vs.get(name).get(i)); + sb.append(" AS "); + sb.append(vs.get(name).get(i).getDataType().toString()); + sb.append("))"); + } + isFirstColumn = false; + } + sb.append(")"); + if (i < size - 1) { + sb.append(", "); + } + } + sb.append(")"); + } + + @Override + public void visit(SQLite3ExpressionBag expr) { + visit(expr.getInnerExpr()); + } + + @Override + public void visit(SQLite3Typeof expr) { + sb.append("typeof("); + visit(expr.getInnerExpr()); + sb.append(")"); + } + + @Override + public void visit(SQLite3ResultMap tableSummary) { + // We use the CASE WHEN THEN END expression to represent the result of an expression for each row in the table. + SQLite3Values values = tableSummary.getValues(); + List columnRefs = tableSummary.getColumns(); + List summary = tableSummary.getSummary(); + + Map> vs = values.getValues(); + int size = vs.get(vs.keySet().iterator().next()).size(); + if (size == 0) { + throw new AssertionError("The result of the expression must not be empty."); + } + List columnNames = values.getColumns().stream().map(c -> c.getName()).collect(Collectors.toList()); + sb.append(" CASE "); + for (int i = 0; i < size; i++) { + sb.append("WHEN "); + for (int j = 0; j < columnNames.size(); ++j) { + visit(columnRefs.get(j)); + if (vs.get(columnNames.get(j)).get(i) instanceof SQLite3NullConstant) { + sb.append(" IS NULL"); + } else { + sb.append(" = "); + sb.append(vs.get(columnNames.get(j)).get(i).toString()); + } + if (j < columnNames.size() - 1) { + sb.append(" AND "); + } + } + sb.append(" THEN "); + visit(summary.get(i)); + sb.append(" "); + } + sb.append("END "); + } } diff --git a/src/sqlancer/sqlite3/SQLite3Visitor.java b/src/sqlancer/sqlite3/SQLite3Visitor.java index f891c17f0..02ac2b4d7 100644 --- a/src/sqlancer/sqlite3/SQLite3Visitor.java +++ b/src/sqlancer/sqlite3/SQLite3Visitor.java @@ -13,14 +13,21 @@ import sqlancer.sqlite3.ast.SQLite3Expression.InOperation; import sqlancer.sqlite3.ast.SQLite3Expression.Join; import sqlancer.sqlite3.ast.SQLite3Expression.MatchOperation; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3Alias; import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3ColumnName; import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3Distinct; import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3Exist; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3ExpressionBag; import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3OrderingTerm; import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3PostfixText; import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3PostfixUnaryOperation; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3ResultMap; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3TableAndColumnRef; import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3TableReference; import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3Text; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3Typeof; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3Values; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3WithClause; import sqlancer.sqlite3.ast.SQLite3Expression.Sqlite3BinaryOperation; import sqlancer.sqlite3.ast.SQLite3Expression.Subquery; import sqlancer.sqlite3.ast.SQLite3Expression.TypeLiteral; @@ -130,6 +137,20 @@ default void visit(SQLite3PostfixUnaryOperation exp) { void visit(SQLite3WindowFunctionFrameSpecBetween between); + void visit(SQLite3Alias alias); + + void visit(SQLite3WithClause withClause); + + void visit(SQLite3TableAndColumnRef tableAndColumnRef); + + void visit(SQLite3Values values); + + void visit(SQLite3ExpressionBag expr); + + void visit(SQLite3Typeof expr); + + void visit(SQLite3ResultMap tableSummary); + default void visit(SQLite3Expression expr) { if (expr instanceof Sqlite3BinaryOperation) { visit((Sqlite3BinaryOperation) expr); @@ -193,6 +214,20 @@ default void visit(SQLite3Expression expr) { visit((SQLite3TableReference) expr); } else if (expr instanceof SQLite3SetClause) { visit((SQLite3SetClause) expr); + } else if (expr instanceof SQLite3Alias) { + visit((SQLite3Alias) expr); + } else if (expr instanceof SQLite3WithClause) { + visit((SQLite3WithClause) expr); + } else if (expr instanceof SQLite3TableAndColumnRef) { + visit((SQLite3TableAndColumnRef) expr); + } else if (expr instanceof SQLite3Values) { + visit((SQLite3Values) expr); + } else if (expr instanceof SQLite3ExpressionBag) { + visit((SQLite3ExpressionBag) expr); + } else if (expr instanceof SQLite3Typeof) { + visit((SQLite3Typeof) expr); + } else if (expr instanceof SQLite3ResultMap) { + visit((SQLite3ResultMap) expr); } else { throw new AssertionError(expr); } diff --git a/src/sqlancer/sqlite3/ast/SQLite3Expression.java b/src/sqlancer/sqlite3/ast/SQLite3Expression.java index 748c70375..ad0f64ad6 100644 --- a/src/sqlancer/sqlite3/ast/SQLite3Expression.java +++ b/src/sqlancer/sqlite3/ast/SQLite3Expression.java @@ -1,11 +1,13 @@ package sqlancer.sqlite3.ast; import java.util.List; +import java.util.Map; import java.util.Optional; import sqlancer.IgnoreMeException; import sqlancer.LikeImplementationHelper; import sqlancer.Randomly; +import sqlancer.common.ast.newast.Expression; import sqlancer.common.visitor.BinaryOperation; import sqlancer.common.visitor.UnaryOperation; import sqlancer.sqlite3.SQLite3CollateHelper; @@ -18,7 +20,7 @@ import sqlancer.sqlite3.schema.SQLite3Schema.SQLite3Column.SQLite3CollateSequence; import sqlancer.sqlite3.schema.SQLite3Schema.SQLite3Table; -public abstract class SQLite3Expression { +public abstract class SQLite3Expression implements Expression { public static class SQLite3TableReference extends SQLite3Expression { @@ -112,9 +114,19 @@ public SQLite3CollateSequence getImplicitCollateSequence() { public static class SQLite3Exist extends SQLite3Expression { private final SQLite3Expression select; + private boolean negated; - public SQLite3Exist(SQLite3Expression select) { + public SQLite3Exist(SQLite3Expression select, boolean negated) { this.select = select; + this.negated = negated; + } + + public void setNegated(boolean negated) { + this.negated = negated; + } + + public boolean getNegated() { + return this.negated; } public SQLite3Expression getExpression() { @@ -128,10 +140,11 @@ public SQLite3CollateSequence getExplicitCollateSequence() { } - public static class Join extends SQLite3Expression { + public static class Join extends SQLite3Expression + implements sqlancer.common.ast.newast.Join { public enum JoinType { - INNER, CROSS, OUTER, NATURAL; + INNER, CROSS, OUTER, NATURAL, RIGHT, FULL; } private final SQLite3Table table; @@ -176,6 +189,7 @@ public SQLite3CollateSequence getExplicitCollateSequence() { return null; } + @Override public void setOnClause(SQLite3Expression onClause) { this.onClause = onClause; } @@ -183,7 +197,6 @@ public void setOnClause(SQLite3Expression onClause) { public void setType(JoinType type) { this.type = type; } - } public static class Subquery extends SQLite3Expression { @@ -1550,4 +1563,198 @@ public boolean omitBracketsWhenPrinting() { } } + public static class SQLite3WithClause extends SQLite3Expression { + + private final SQLite3Expression left; + private SQLite3Expression right; + + public SQLite3WithClause(SQLite3Expression left, SQLite3Expression right) { + this.left = left; + this.right = right; + } + + public SQLite3Expression getLeft() { + return this.left; + } + + public SQLite3Expression getRight() { + return this.right; + } + + public void updateRight(SQLite3Expression right) { + this.right = right; + } + + @Override + public SQLite3CollateSequence getExplicitCollateSequence() { + return null; + } + } + + public static class SQLite3Alias extends SQLite3Expression { + + private final SQLite3Expression originalExpression; + private final SQLite3Expression aliasExpression; + + public SQLite3Alias(SQLite3Expression originalExpression, SQLite3Expression aliasExpression) { + this.originalExpression = originalExpression; + this.aliasExpression = aliasExpression; + } + + @Override + public SQLite3CollateSequence getExplicitCollateSequence() { + return null; + } + + public SQLite3Expression getOriginalExpression() { + return originalExpression; + } + + public SQLite3Expression getAliasExpression() { + return aliasExpression; + } + } + + public static class SQLite3TableAndColumnRef extends SQLite3Expression { + + private final SQLite3Table table; + + public SQLite3TableAndColumnRef(SQLite3Table table) { + this.table = table; + } + + public SQLite3Table getTable() { + return this.table; + } + + public String getString() { + StringBuilder sb = new StringBuilder(); + sb.append(table.getName()); + sb.append("("); + Boolean isFirstColumn = true; + for (SQLite3Column c : this.table.getColumns()) { + if (!isFirstColumn) { + sb.append(", "); + } + sb.append(c.getName()); + isFirstColumn = false; + } + sb.append(")"); + return sb.toString(); + } + + @Override + public SQLite3CollateSequence getExplicitCollateSequence() { + return null; + } + } + + public static class SQLite3Values extends SQLite3Expression { + + private final Map> values; + private final List columns; + + public SQLite3Values(Map> values, List columns) { + this.values = values; + this.columns = columns; + } + + public Map> getValues() { + return this.values; + } + + public List getColumns() { + return this.columns; + } + + @Override + public SQLite3CollateSequence getExplicitCollateSequence() { + return null; + } + } + + // The ExpressionBag is not a built-in SQL feature, + // but rather a utility class used in CODDTest's oracle construction + // to substitute expressions with their corresponding constant values. + public static class SQLite3ExpressionBag extends SQLite3Expression { + private SQLite3Expression innerExpr; + + public SQLite3ExpressionBag(SQLite3Expression innerExpr) { + this.innerExpr = innerExpr; + } + + public void updateInnerExpr(SQLite3Expression innerExpr) { + this.innerExpr = innerExpr; + } + + public SQLite3Expression getInnerExpr() { + return innerExpr; + } + + @Override + public SQLite3CollateSequence getExplicitCollateSequence() { + return null; + } + + } + + public static class SQLite3Typeof extends SQLite3Expression { + private final SQLite3Expression innerExpr; + + public SQLite3Typeof(SQLite3Expression innerExpr) { + this.innerExpr = innerExpr; + } + + public SQLite3Expression getInnerExpr() { + return innerExpr; + } + + @Override + public SQLite3CollateSequence getExplicitCollateSequence() { + return null; + } + + } + + public static class SQLite3ResultMap extends SQLite3Expression { + private final SQLite3Values values; + private final List columns; + private final List summary; + private final SQLite3DataType summaryDataType; + + public SQLite3ResultMap(SQLite3Values values, List columns, List summary, + SQLite3DataType summaryDataType) { + this.values = values; + this.columns = columns; + this.summary = summary; + this.summaryDataType = summaryDataType; + + Map> vs = values.getValues(); + if (vs.get(vs.keySet().iterator().next()).size() != summary.size()) { + throw new AssertionError(); + } + } + + public SQLite3Values getValues() { + return this.values; + } + + public List getColumns() { + return this.columns; + } + + public List getSummary() { + return this.summary; + } + + public SQLite3DataType getSummaryDataType() { + return this.summaryDataType; + } + + @Override + public SQLite3CollateSequence getExplicitCollateSequence() { + return null; + } + + } } diff --git a/src/sqlancer/sqlite3/ast/SQLite3Select.java b/src/sqlancer/sqlite3/ast/SQLite3Select.java index bff7fb775..176057e8a 100644 --- a/src/sqlancer/sqlite3/ast/SQLite3Select.java +++ b/src/sqlancer/sqlite3/ast/SQLite3Select.java @@ -4,9 +4,16 @@ import java.util.Collections; import java.util.List; +import sqlancer.IgnoreMeException; +import sqlancer.common.ast.newast.Select; +import sqlancer.sqlite3.SQLite3Visitor; +import sqlancer.sqlite3.ast.SQLite3Expression.Join; +import sqlancer.sqlite3.schema.SQLite3Schema.SQLite3Column; import sqlancer.sqlite3.schema.SQLite3Schema.SQLite3Column.SQLite3CollateSequence; +import sqlancer.sqlite3.schema.SQLite3Schema.SQLite3Table; -public class SQLite3Select extends SQLite3Expression { +public class SQLite3Select extends SQLite3Expression + implements Select { private SelectType fromOptions = SelectType.ALL; private List fromList = Collections.emptyList(); @@ -18,6 +25,7 @@ public class SQLite3Select extends SQLite3Expression { private List fetchColumns = Collections.emptyList(); private List joinStatements = Collections.emptyList(); private SQLite3Expression havingClause; + private SQLite3WithClause withClause; public SQLite3Select() { } @@ -36,6 +44,7 @@ public SQLite3Select(SQLite3Select other) { joinStatements.add(new Join(j)); } havingClause = other.havingClause; + withClause = other.withClause; } public enum SelectType { @@ -46,10 +55,6 @@ public void setSelectType(SelectType fromOptions) { this.setFromOptions(fromOptions); } - public void setFromTables(List fromTables) { - this.setFromList(fromTables); - } - public SelectType getFromOptions() { return fromOptions; } @@ -58,66 +63,82 @@ public void setFromOptions(SelectType fromOptions) { this.fromOptions = fromOptions; } + @Override public List getFromList() { return fromList; } + @Override public void setFromList(List fromList) { this.fromList = fromList; } + @Override public SQLite3Expression getWhereClause() { return whereClause; } + @Override public void setWhereClause(SQLite3Expression whereClause) { this.whereClause = whereClause; } + @Override public void setGroupByClause(List groupByClause) { this.groupByClause = groupByClause; } + @Override public List getGroupByClause() { return groupByClause; } + @Override public void setLimitClause(SQLite3Expression limitClause) { this.limitClause = limitClause; } + @Override public SQLite3Expression getLimitClause() { return limitClause; } - public List getOrderByClause() { + @Override + public List getOrderByClauses() { return orderByClause; } - public void setOrderByExpressions(List orderBy) { + @Override + public void setOrderByClauses(List orderBy) { this.orderByClause = orderBy; } + @Override public void setOffsetClause(SQLite3Expression offsetClause) { this.offsetClause = offsetClause; } + @Override public SQLite3Expression getOffsetClause() { return offsetClause; } + @Override public void setFetchColumns(List fetchColumns) { this.fetchColumns = fetchColumns; } + @Override public List getFetchColumns() { return fetchColumns; } + @Override public void setJoinClauses(List joinStatements) { this.joinStatements = joinStatements; } + @Override public List getJoinClauses() { return joinStatements; } @@ -128,13 +149,50 @@ public SQLite3CollateSequence getExplicitCollateSequence() { return null; } + @Override public void setHavingClause(SQLite3Expression havingClause) { this.havingClause = havingClause; } + @Override public SQLite3Expression getHavingClause() { assert orderByClause != null; return havingClause; } + @Override + public String asString() { + return SQLite3Visitor.asString(this); + } + + public void setWithClause(SQLite3WithClause withClause) { + this.withClause = withClause; + } + + public void updateWithClauseRight(SQLite3Expression withClauseRight) { + this.withClause.updateRight(withClauseRight); + } + + public SQLite3Expression getWithClause() { + return this.withClause; + } + + // This method is used in CODDTest to test subquery by replacing a table name + // in the SELECT clause with a derived table expression. + public void replaceFromTable(String tableName, SQLite3Expression newFromExpression) { + int replaceIdx = -1; + for (int i = 0; i < fromList.size(); ++i) { + SQLite3Expression f = fromList.get(i); + if (f instanceof SQLite3TableReference) { + SQLite3TableReference tableRef = (SQLite3TableReference) f; + if (tableRef.getTable().getName().equals(tableName)) { + replaceIdx = i; + } + } + } + if (replaceIdx == -1) { + throw new IgnoreMeException(); + } + fromList.set(replaceIdx, newFromExpression); + } } diff --git a/src/sqlancer/sqlite3/gen/SQLite3ColumnBuilder.java b/src/sqlancer/sqlite3/gen/SQLite3ColumnBuilder.java index e07533ed8..4e83fbca5 100644 --- a/src/sqlancer/sqlite3/gen/SQLite3ColumnBuilder.java +++ b/src/sqlancer/sqlite3/gen/SQLite3ColumnBuilder.java @@ -5,7 +5,7 @@ import sqlancer.Randomly; import sqlancer.sqlite3.SQLite3GlobalState; -import sqlancer.sqlite3.SQLite3Options.SQLite3OracleFactory; +import sqlancer.sqlite3.SQLite3OracleFactory; import sqlancer.sqlite3.SQLite3Visitor; import sqlancer.sqlite3.schema.SQLite3Schema.SQLite3Column; @@ -51,7 +51,7 @@ public String createColumn(String columnName, SQLite3GlobalState globalState, Li if (Randomly.getBooleanWithRatherLowProbability()) { List constraints = Randomly.subset(Constraints.values()); if (!Randomly.getBooleanWithSmallProbability() - || globalState.getDbmsSpecificOptions().testGeneratedColumns) { + || !globalState.getDbmsSpecificOptions().testGeneratedColumns) { constraints.remove(Constraints.GENERATED_AS); } if (constraints.contains(Constraints.GENERATED_AS)) { diff --git a/src/sqlancer/sqlite3/gen/SQLite3CreateVirtualRtreeTabelGenerator.java b/src/sqlancer/sqlite3/gen/SQLite3CreateVirtualRtreeTabelGenerator.java index 08effcdc9..61f821382 100644 --- a/src/sqlancer/sqlite3/gen/SQLite3CreateVirtualRtreeTabelGenerator.java +++ b/src/sqlancer/sqlite3/gen/SQLite3CreateVirtualRtreeTabelGenerator.java @@ -3,6 +3,7 @@ import java.util.ArrayList; import java.util.List; +import sqlancer.IgnoreMeException; import sqlancer.Randomly; import sqlancer.common.DBMSCommon; import sqlancer.common.query.ExpectedErrors; @@ -15,6 +16,14 @@ public final class SQLite3CreateVirtualRtreeTabelGenerator { private SQLite3CreateVirtualRtreeTabelGenerator() { } + public static SQLQueryAdapter createRandomTableStatement(SQLite3GlobalState globalState) { + if (globalState.getSchema().getTables().getTables() + .size() > globalState.getDbmsSpecificOptions().maxNumTables) { + throw new IgnoreMeException(); + } + return createTableStatement(globalState.getSchema().getFreeRtreeTableName(), globalState); + } + public static SQLQueryAdapter createTableStatement(String rTreeTableName, SQLite3GlobalState globalState) { ExpectedErrors errors = new ExpectedErrors(); List columns = new ArrayList<>(); diff --git a/src/sqlancer/sqlite3/gen/SQLite3ExplainGenerator.java b/src/sqlancer/sqlite3/gen/SQLite3ExplainGenerator.java index e9ede64d9..8c00b3231 100644 --- a/src/sqlancer/sqlite3/gen/SQLite3ExplainGenerator.java +++ b/src/sqlancer/sqlite3/gen/SQLite3ExplainGenerator.java @@ -26,4 +26,11 @@ public static SQLQueryAdapter explain(SQLite3GlobalState globalState) throws Exc return new SQLQueryAdapter(sb.toString(), query.getExpectedErrors()); } + public static String explain(String selectStr) throws Exception { + StringBuilder sb = new StringBuilder(); + sb.append("EXPLAIN QUERY PLAN "); + sb.append(selectStr); + return sb.toString(); + } + } diff --git a/src/sqlancer/sqlite3/gen/SQLite3ExpressionGenerator.java b/src/sqlancer/sqlite3/gen/SQLite3ExpressionGenerator.java index d21f00af3..e4f0741f9 100644 --- a/src/sqlancer/sqlite3/gen/SQLite3ExpressionGenerator.java +++ b/src/sqlancer/sqlite3/gen/SQLite3ExpressionGenerator.java @@ -8,6 +8,9 @@ import sqlancer.Randomly; import sqlancer.common.gen.ExpressionGenerator; +import sqlancer.common.gen.NoRECGenerator; +import sqlancer.common.gen.TLPWhereGenerator; +import sqlancer.common.schema.AbstractTables; import sqlancer.sqlite3.SQLite3GlobalState; import sqlancer.sqlite3.ast.SQLite3Aggregate; import sqlancer.sqlite3.ast.SQLite3Aggregate.SQLite3AggregateFunction; @@ -31,12 +34,14 @@ import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3PostfixText; import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3PostfixUnaryOperation; import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3PostfixUnaryOperation.PostfixUnaryOperator; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3TableReference; import sqlancer.sqlite3.ast.SQLite3Expression.Sqlite3BinaryOperation; import sqlancer.sqlite3.ast.SQLite3Expression.Sqlite3BinaryOperation.BinaryOperator; import sqlancer.sqlite3.ast.SQLite3Expression.TypeLiteral; import sqlancer.sqlite3.ast.SQLite3Function; import sqlancer.sqlite3.ast.SQLite3Function.ComputableFunction; import sqlancer.sqlite3.ast.SQLite3RowValueExpression; +import sqlancer.sqlite3.ast.SQLite3Select; import sqlancer.sqlite3.ast.SQLite3UnaryOperation; import sqlancer.sqlite3.ast.SQLite3UnaryOperation.UnaryOperator; import sqlancer.sqlite3.oracle.SQLite3RandomQuerySynthesizer; @@ -45,12 +50,15 @@ import sqlancer.sqlite3.schema.SQLite3Schema.SQLite3RowValue; import sqlancer.sqlite3.schema.SQLite3Schema.SQLite3Table; -public class SQLite3ExpressionGenerator implements ExpressionGenerator { +public class SQLite3ExpressionGenerator implements ExpressionGenerator, + NoRECGenerator, + TLPWhereGenerator { private SQLite3RowValue rw; private final SQLite3GlobalState globalState; private boolean tryToGenerateKnownResult; private List columns = Collections.emptyList(); + private List targetTables; private final Randomly r; private boolean deterministicOnly; private boolean allowMatchClause; @@ -63,6 +71,7 @@ public SQLite3ExpressionGenerator(SQLite3ExpressionGenerator other) { this.globalState = other.globalState; this.tryToGenerateKnownResult = other.tryToGenerateKnownResult; this.columns = new ArrayList<>(other.columns); + this.targetTables = other.targetTables; this.r = other.r; this.deterministicOnly = other.deterministicOnly; this.allowMatchClause = other.allowMatchClause; @@ -126,6 +135,7 @@ public static SQLite3Expression getRandomLiteralValue(SQLite3GlobalState globalS return new SQLite3ExpressionGenerator(globalState).getRandomLiteralValueInternal(globalState.getRandomly()); } + @Override public List generateOrderBys() { List expressions = new ArrayList<>(); for (int i = 0; i < Randomly.smallNumber() + 1; i++) { @@ -139,19 +149,25 @@ public List getRandomJoinClauses(List tables) { if (!globalState.getDbmsSpecificOptions().testJoins) { return joinStatements; } + List options = new ArrayList<>(Arrays.asList(JoinType.values())); if (Randomly.getBoolean() && tables.size() > 1) { int nrJoinClauses = (int) Randomly.getNotCachedInteger(0, tables.size()); + // Natural join is incompatible with other joins + // because it needs unique column names + // while other joins will produce duplicate column names + if (nrJoinClauses > 1) { + options.remove(JoinType.NATURAL); + } for (int i = 0; i < nrJoinClauses; i++) { SQLite3Expression joinClause = generateExpression(); SQLite3Table table = Randomly.fromList(tables); tables.remove(table); - JoinType options; - options = Randomly.fromOptions(JoinType.INNER, JoinType.CROSS, JoinType.OUTER, JoinType.NATURAL); - if (options == JoinType.NATURAL) { + JoinType selectedOption = Randomly.fromList(options); + if (selectedOption == JoinType.NATURAL) { // NATURAL joins do not have an ON clause joinClause = null; } - Join j = new SQLite3Expression.Join(table, joinClause, options); + Join j = new SQLite3Expression.Join(table, joinClause, selectedOption); joinStatements.add(j); } @@ -547,6 +563,12 @@ private SQLite3Expression getFunction(SQLite3GlobalState globalState, int depth) nrArgs += Randomly.smallNumber(); } List expressions = randomFunction.generateArguments(nrArgs, depth + 1, this); + // The second argument of LIKELIHOOD must be a float number within 0.0 -1.0 + if (randomFunction == AnyFunction.LIKELIHOOD) { + SQLite3Expression lastArg = SQLite3Constant.createRealConstant(Randomly.getPercentage()); + expressions.remove(expressions.size() - 1); + expressions.add(lastArg); + } return new SQLite3Expression.Function(randomFunction.toString(), expressions.toArray(new SQLite3Expression[0])); } @@ -604,6 +626,11 @@ private SQLite3Expression getComputableFunction(int depth) { args[i] = new SQLite3Distinct(args[i]); } } + // The second argument of LIKELIHOOD must be a float number within 0.0 -1.0 + if (func == ComputableFunction.LIKELIHOOD) { + SQLite3Expression lastArg = SQLite3Constant.createRealConstant(Randomly.getPercentage()); + args[args.length - 1] = lastArg; + } return new SQLite3Function(func, args); } @@ -685,4 +712,82 @@ public SQLite3Expression generateResultKnownExpression() { return expr; } + @Override + public SQLite3ExpressionGenerator setTablesAndColumns(AbstractTables targetTables) { + SQLite3ExpressionGenerator gen = new SQLite3ExpressionGenerator(this); + gen.targetTables = targetTables.getTables(); + gen.columns = targetTables.getColumns(); + return gen; + } + + @Override + public SQLite3Expression generateBooleanExpression() { + return generateExpression(); + } + + @Override + public SQLite3Select generateSelect() { + return new SQLite3Select(); + } + + @Override + public List getRandomJoinClauses() { + return getRandomJoinClauses(targetTables); + } + + @Override + public List getTableRefs() { + List tableRefs = new ArrayList<>(); + for (SQLite3Table t : targetTables) { + SQLite3TableReference tableRef; + if (Randomly.getBooleanWithSmallProbability() && !globalState.getSchema().getIndexNames().isEmpty()) { + tableRef = new SQLite3TableReference(globalState.getSchema().getRandomIndexOrBailout(), t); + } else { + tableRef = new SQLite3TableReference(t); + } + tableRefs.add(tableRef); + } + return tableRefs; + } + + @Override + public List generateFetchColumns(boolean shouldCreateDummy) { + List columns = new ArrayList<>(); + if (shouldCreateDummy && Randomly.getBoolean()) { + columns.add(new SQLite3ColumnName(SQLite3Column.createDummy("*"), null)); + } else { + columns = Randomly.nonEmptySubset(this.columns).stream().map(c -> new SQLite3ColumnName(c, null)) + .collect(Collectors.toList()); + } + return columns; + } + + @Override + public String generateOptimizedQueryString(SQLite3Select select, SQLite3Expression whereCondition, + boolean shouldUseAggregate) { + if (Randomly.getBoolean()) { + select.setOrderByClauses(generateOrderBys()); + } + if (shouldUseAggregate) { + select.setFetchColumns(Arrays.asList(new SQLite3Aggregate(Collections.emptyList(), + SQLite3Aggregate.SQLite3AggregateFunction.COUNT_ALL))); + } else { + SQLite3ColumnName aggr = new SQLite3ColumnName(SQLite3Column.createDummy("*"), null); + select.setFetchColumns(Arrays.asList(aggr)); + } + select.setWhereClause(whereCondition); + + return select.asString(); + } + + @Override + public String generateUnoptimizedQueryString(SQLite3Select select, SQLite3Expression whereCondition) { + SQLite3PostfixUnaryOperation isTrue = new SQLite3PostfixUnaryOperation(PostfixUnaryOperator.IS_TRUE, + whereCondition); + SQLite3PostfixText asText = new SQLite3PostfixText(isTrue, " as count", null); + select.setFetchColumns(Arrays.asList(asText)); + select.setWhereClause(null); + + return "SELECT SUM(count) FROM (" + select.asString() + ")"; + } } diff --git a/src/sqlancer/sqlite3/gen/ddl/SQLite3CreateVirtualFTSTableGenerator.java b/src/sqlancer/sqlite3/gen/ddl/SQLite3CreateVirtualFTSTableGenerator.java index a0f79fa7e..3ae9d16d5 100644 --- a/src/sqlancer/sqlite3/gen/ddl/SQLite3CreateVirtualFTSTableGenerator.java +++ b/src/sqlancer/sqlite3/gen/ddl/SQLite3CreateVirtualFTSTableGenerator.java @@ -4,10 +4,12 @@ import java.util.Arrays; import java.util.List; +import sqlancer.IgnoreMeException; import sqlancer.Randomly; import sqlancer.common.DBMSCommon; import sqlancer.common.query.ExpectedErrors; import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.sqlite3.SQLite3GlobalState; public class SQLite3CreateVirtualFTSTableGenerator { @@ -20,6 +22,14 @@ public SQLite3CreateVirtualFTSTableGenerator(String tableName, Randomly r) { this.r = r; } + public static SQLQueryAdapter createRandomTableStatement(SQLite3GlobalState globalState) { + if (globalState.getSchema().getTables().getTables() + .size() > globalState.getDbmsSpecificOptions().maxNumTables) { + throw new IgnoreMeException(); + } + return createTableStatement(globalState.getSchema().getFreeVirtualTableName(), globalState.getRandomly()); + } + public static SQLQueryAdapter createTableStatement(String tableName, Randomly r) { return new SQLite3CreateVirtualFTSTableGenerator(tableName, r).create(); } diff --git a/src/sqlancer/sqlite3/gen/ddl/SQLite3IndexGenerator.java b/src/sqlancer/sqlite3/gen/ddl/SQLite3IndexGenerator.java index 0a33016fd..c81d895ac 100644 --- a/src/sqlancer/sqlite3/gen/ddl/SQLite3IndexGenerator.java +++ b/src/sqlancer/sqlite3/gen/ddl/SQLite3IndexGenerator.java @@ -3,6 +3,7 @@ import java.sql.SQLException; import java.util.List; +import sqlancer.IgnoreMeException; import sqlancer.Randomly; import sqlancer.common.query.ExpectedErrors; import sqlancer.common.query.SQLQueryAdapter; @@ -23,6 +24,9 @@ public class SQLite3IndexGenerator { private final SQLite3GlobalState globalState; public static SQLQueryAdapter insertIndex(SQLite3GlobalState globalState) throws SQLException { + if (globalState.getSchema().getIndexNames().size() >= globalState.getDbmsSpecificOptions().maxNumIndexes) { + throw new IgnoreMeException(); + } return new SQLite3IndexGenerator(globalState).create(); } diff --git a/src/sqlancer/sqlite3/gen/ddl/SQLite3TableGenerator.java b/src/sqlancer/sqlite3/gen/ddl/SQLite3TableGenerator.java index cc4731911..7847208a3 100644 --- a/src/sqlancer/sqlite3/gen/ddl/SQLite3TableGenerator.java +++ b/src/sqlancer/sqlite3/gen/ddl/SQLite3TableGenerator.java @@ -5,13 +5,14 @@ import java.util.List; import java.util.stream.Collectors; +import sqlancer.IgnoreMeException; import sqlancer.Randomly; import sqlancer.common.DBMSCommon; import sqlancer.common.query.ExpectedErrors; import sqlancer.common.query.SQLQueryAdapter; import sqlancer.sqlite3.SQLite3Errors; import sqlancer.sqlite3.SQLite3GlobalState; -import sqlancer.sqlite3.SQLite3Options.SQLite3OracleFactory; +import sqlancer.sqlite3.SQLite3OracleFactory; import sqlancer.sqlite3.gen.SQLite3ColumnBuilder; import sqlancer.sqlite3.gen.SQLite3Common; import sqlancer.sqlite3.schema.SQLite3Schema; @@ -46,6 +47,14 @@ public SQLite3TableGenerator(String tableName, SQLite3GlobalState globalState) { this.existingSchema = globalState.getSchema(); } + public static SQLQueryAdapter createRandomTableStatement(SQLite3GlobalState globalState) { + if (globalState.getSchema().getTables().getTables() + .size() > globalState.getDbmsSpecificOptions().maxNumTables) { + throw new IgnoreMeException(); + } + return createTableStatement(globalState.getSchema().getFreeTableName(), globalState); + } + public static SQLQueryAdapter createTableStatement(String tableName, SQLite3GlobalState globalState) { SQLite3TableGenerator sqLite3TableGenerator = new SQLite3TableGenerator(tableName, globalState); sqLite3TableGenerator.start(); diff --git a/src/sqlancer/sqlite3/gen/ddl/SQLite3ViewGenerator.java b/src/sqlancer/sqlite3/gen/ddl/SQLite3ViewGenerator.java index ffa3af6cd..bc605a015 100644 --- a/src/sqlancer/sqlite3/gen/ddl/SQLite3ViewGenerator.java +++ b/src/sqlancer/sqlite3/gen/ddl/SQLite3ViewGenerator.java @@ -2,13 +2,14 @@ import java.sql.SQLException; +import sqlancer.IgnoreMeException; import sqlancer.Randomly; import sqlancer.common.DBMSCommon; import sqlancer.common.query.ExpectedErrors; import sqlancer.common.query.SQLQueryAdapter; import sqlancer.sqlite3.SQLite3Errors; import sqlancer.sqlite3.SQLite3GlobalState; -import sqlancer.sqlite3.SQLite3Options.SQLite3OracleFactory; +import sqlancer.sqlite3.SQLite3OracleFactory; import sqlancer.sqlite3.SQLite3Visitor; import sqlancer.sqlite3.ast.SQLite3Expression; import sqlancer.sqlite3.ast.SQLite3Select; @@ -29,9 +30,13 @@ public static SQLQueryAdapter dropView(SQLite3GlobalState globalState) { } public static SQLQueryAdapter generate(SQLite3GlobalState globalState) throws SQLException { + if (globalState.getSchema().getTables().getTables() + .size() >= globalState.getDbmsSpecificOptions().maxNumTables) { + throw new IgnoreMeException(); + } StringBuilder sb = new StringBuilder(); sb.append("CREATE"); - if (Randomly.getBoolean()) { + if (globalState.getDbmsSpecificOptions().testTempTables && Randomly.getBoolean()) { sb.append(" "); sb.append(Randomly.fromOptions("TEMP", "TEMPORARY")); } diff --git a/src/sqlancer/sqlite3/gen/dml/SQLite3UpdateGenerator.java b/src/sqlancer/sqlite3/gen/dml/SQLite3UpdateGenerator.java index 791b7f5eb..5a17ad339 100644 --- a/src/sqlancer/sqlite3/gen/dml/SQLite3UpdateGenerator.java +++ b/src/sqlancer/sqlite3/gen/dml/SQLite3UpdateGenerator.java @@ -4,7 +4,7 @@ import java.util.stream.Collectors; import sqlancer.Randomly; -import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.gen.AbstractUpdateGenerator; import sqlancer.common.query.SQLQueryAdapter; import sqlancer.sqlite3.SQLite3Errors; import sqlancer.sqlite3.SQLite3GlobalState; @@ -14,12 +14,10 @@ import sqlancer.sqlite3.schema.SQLite3Schema.SQLite3Column; import sqlancer.sqlite3.schema.SQLite3Schema.SQLite3Table; -public class SQLite3UpdateGenerator { +public class SQLite3UpdateGenerator extends AbstractUpdateGenerator { - private final StringBuilder sb = new StringBuilder(); - private final Randomly r; - private final ExpectedErrors errors = new ExpectedErrors(); private final SQLite3GlobalState globalState; + private final Randomly r; public SQLite3UpdateGenerator(SQLite3GlobalState globalState, Randomly r) { this.globalState = globalState; @@ -34,10 +32,11 @@ public static SQLQueryAdapter updateRow(SQLite3GlobalState globalState) { public static SQLQueryAdapter updateRow(SQLite3GlobalState globalState, SQLite3Table table) { SQLite3UpdateGenerator generator = new SQLite3UpdateGenerator(globalState, globalState.getRandomly()); - return generator.update(table); + return generator.generate(table); } - private SQLQueryAdapter update(SQLite3Table table) { + private SQLQueryAdapter generate(SQLite3Table table) { + List columnsToUpdate = Randomly.nonEmptySubsetPotentialDuplicates(table.getColumns()); sb.append("UPDATE "); if (Randomly.getBoolean()) { sb.append("OR IGNORE "); @@ -55,7 +54,6 @@ private SQLQueryAdapter update(SQLite3Table table) { sb.append(table.getName()); sb.append(" SET "); - List columnsToUpdate = Randomly.nonEmptySubsetPotentialDuplicates(table.getColumns()); if (Randomly.getBoolean()) { sb.append("("); sb.append(columnsToUpdate.stream().map(c -> c.getName()).collect(Collectors.joining(", "))); @@ -66,19 +64,12 @@ private SQLQueryAdapter update(SQLite3Table table) { if (i != 0) { sb.append(", "); } - getToUpdateValue(columnsToUpdate, i); + updateValue(columnsToUpdate.get(i)); } sb.append(")"); // row values } else { - for (int i = 0; i < columnsToUpdate.size(); i++) { - if (i != 0) { - sb.append(", "); - } - sb.append(columnsToUpdate.get(i).getName()); - sb.append(" = "); - getToUpdateValue(columnsToUpdate, i); - } + updateColumns(columnsToUpdate); } if (Randomly.getBoolean()) { @@ -111,8 +102,9 @@ private SQLQueryAdapter update(SQLite3Table table) { } - private void getToUpdateValue(List columnsToUpdate, int i) { - if (columnsToUpdate.get(i).isIntegerPrimaryKey()) { + @Override + protected void updateValue(SQLite3Column column) { + if (column.isIntegerPrimaryKey()) { sb.append(SQLite3Visitor.asString(SQLite3Constant.createIntConstant(r.getInteger()))); } else { sb.append(SQLite3Visitor.asString(SQLite3ExpressionGenerator.getRandomLiteralValue(globalState))); diff --git a/src/sqlancer/sqlite3/oracle/SQLite3CODDTestOracle.java b/src/sqlancer/sqlite3/oracle/SQLite3CODDTestOracle.java new file mode 100644 index 000000000..532709676 --- /dev/null +++ b/src/sqlancer/sqlite3/oracle/SQLite3CODDTestOracle.java @@ -0,0 +1,981 @@ +package sqlancer.sqlite3.oracle; + +import java.math.BigDecimal; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import sqlancer.IgnoreMeException; +import sqlancer.Main; +import sqlancer.Randomly; +import sqlancer.Reproducer; +import sqlancer.common.oracle.CODDTestBase; +import sqlancer.common.oracle.TestOracle; +import sqlancer.sqlite3.SQLite3Errors; +import sqlancer.sqlite3.SQLite3GlobalState; +import sqlancer.sqlite3.SQLite3Provider; +import sqlancer.sqlite3.SQLite3Visitor; +import sqlancer.sqlite3.ast.SQLite3Aggregate; +import sqlancer.sqlite3.ast.SQLite3Constant; +import sqlancer.sqlite3.ast.SQLite3Constant.SQLite3TextConstant; +import sqlancer.sqlite3.ast.SQLite3Expression; +import sqlancer.sqlite3.ast.SQLite3Expression.InOperation; +import sqlancer.sqlite3.ast.SQLite3Expression.Join; +import sqlancer.sqlite3.ast.SQLite3Expression.Join.JoinType; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3Alias; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3ColumnName; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3Exist; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3ExpressionBag; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3OrderingTerm; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3OrderingTerm.Ordering; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3PostfixText; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3ResultMap; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3TableAndColumnRef; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3TableReference; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3Typeof; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3Values; +import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3WithClause; +import sqlancer.sqlite3.ast.SQLite3Expression.Sqlite3BinaryOperation.BinaryOperator; +import sqlancer.sqlite3.ast.SQLite3Select; +import sqlancer.sqlite3.gen.SQLite3Common; +import sqlancer.sqlite3.gen.SQLite3ExpressionGenerator; +import sqlancer.sqlite3.schema.SQLite3DataType; +import sqlancer.sqlite3.schema.SQLite3Schema; +import sqlancer.sqlite3.schema.SQLite3Schema.SQLite3Column; +import sqlancer.sqlite3.schema.SQLite3Schema.SQLite3Table; +import sqlancer.sqlite3.schema.SQLite3Schema.SQLite3Tables; + +public class SQLite3CODDTestOracle extends CODDTestBase implements TestOracle { + + private final SQLite3Schema s; + private SQLite3ExpressionGenerator gen; + private Reproducer reproducer; + + private static final String TEMP_TABLE_NAME = "temp_table"; + + private SQLite3Expression foldedExpr; + private SQLite3Expression constantResOfFoldedExpr; + + private List tablesFromOuterContext = new ArrayList<>(); + private List joinsInExpr; + + Map> auxiliaryQueryResult = new HashMap<>(); + Map> selectResult = new HashMap<>(); + + Boolean useSubqueryAsFoldedExpr; + Boolean useCorrelatedSubqueryAsFoldedExpr; + + public SQLite3CODDTestOracle(SQLite3GlobalState globalState) { + super(globalState); + this.s = globalState.getSchema(); + SQLite3Errors.addExpectedExpressionErrors(errors); + SQLite3Errors.addMatchQueryErrors(errors); + SQLite3Errors.addQueryErrors(errors); + // errors.add("misuse of aggregate"); + // errors.add("misuse of window function"); + // errors.add("second argument to nth_value must be a positive integer"); + // errors.add("no such table"); + // errors.add("no query solution"); + // errors.add("unable to use function MATCH in the requested context"); + // errors.add("[SQLITE_ERROR] SQL error or missing database (unrecognized token:"); + } + + @Override + public void check() throws SQLException { + reproducer = null; + + joinsInExpr = null; + tablesFromOuterContext.clear(); + + useSubqueryAsFoldedExpr = useSubquery(); + useCorrelatedSubqueryAsFoldedExpr = useCorrelatedSubquery(); + + SQLite3Select auxiliaryQuery = null; + if (useSubqueryAsFoldedExpr) { + if (useCorrelatedSubqueryAsFoldedExpr) { + auxiliaryQuery = genSelectWithCorrelatedSubquery(); + auxiliaryQueryString = SQLite3Visitor.asString(auxiliaryQuery); + + auxiliaryQueryResult.putAll(selectResult); + } else { + auxiliaryQuery = genSelectExpression(null, null); + auxiliaryQueryString = SQLite3Visitor.asString(auxiliaryQuery); + auxiliaryQueryResult = getQueryResult(auxiliaryQueryString, state); + } + } else { + auxiliaryQuery = genSimpleSelect(); + auxiliaryQueryString = SQLite3Visitor.asString(auxiliaryQuery); + + auxiliaryQueryResult.putAll(selectResult); + } + + SQLite3Select originalQuery = null; + + Map> foldedResult = new HashMap<>(); + Map> originalResult = new HashMap<>(); + + // dependent expression + if (!useSubqueryAsFoldedExpr || useSubqueryAsFoldedExpr && useCorrelatedSubqueryAsFoldedExpr) { + // original query + SQLite3ExpressionBag specificCondition = new SQLite3ExpressionBag(this.foldedExpr); + originalQuery = this.genSelectExpression(null, specificCondition); + originalQueryString = SQLite3Visitor.asString(originalQuery); + originalResult = getQueryResult(originalQueryString, state); + + // folded query + specificCondition.updateInnerExpr(this.constantResOfFoldedExpr); + foldedQueryString = SQLite3Visitor.asString(originalQuery); + foldedResult = getQueryResult(foldedQueryString, state); + } else if (auxiliaryQueryResult.isEmpty() + || auxiliaryQueryResult.get(auxiliaryQueryResult.keySet().iterator().next()).isEmpty()) { + // independent expression + // empty result, put the inner query in (NOT) EXIST + boolean isNegated = !Randomly.getBoolean(); + // original query + SQLite3Exist existExpr = new SQLite3Exist(new SQLite3Select(auxiliaryQuery), isNegated); + SQLite3ExpressionBag specificCondition = new SQLite3ExpressionBag(existExpr); + + originalQuery = this.genSelectExpression(null, specificCondition); + originalQueryString = SQLite3Visitor.asString(originalQuery); + originalResult = getQueryResult(originalQueryString, state); + + // folded query + SQLite3Expression equivalentExpr = isNegated ? SQLite3Constant.createTrue() : SQLite3Constant.createFalse(); + specificCondition.updateInnerExpr(equivalentExpr); + foldedQueryString = SQLite3Visitor.asString(originalQuery); + foldedResult = getQueryResult(foldedQueryString, state); + } else if (auxiliaryQueryResult.size() == 1 + && auxiliaryQueryResult.get(auxiliaryQueryResult.keySet().toArray()[0]).size() == 1 + && Randomly.getBoolean()) { + // Scalar Subquery: 1 column and 1 row, consider the inner query as a constant + // original query + SQLite3ExpressionBag specificCondition = new SQLite3ExpressionBag(auxiliaryQuery); + originalQuery = this.genSelectExpression(null, specificCondition); + originalQueryString = SQLite3Visitor.asString(originalQuery); + originalResult = getQueryResult(originalQueryString, state); + + // folded query + SQLite3Expression equivalentExpr = auxiliaryQueryResult.get(auxiliaryQueryResult.keySet().toArray()[0]) + .get(0); + specificCondition.updateInnerExpr(equivalentExpr); + foldedQueryString = SQLite3Visitor.asString(originalQuery); + foldedResult = getQueryResult(foldedQueryString, state); + } else if (auxiliaryQueryResult.size() == 1 && Randomly.getBooleanWithRatherLowProbability() + && enableInOperator()) { + // one column + // original query + List columns = s.getRandomTableNonEmptyTables().getColumns(); + SQLite3ColumnName selectedColumn = new SQLite3ColumnName(Randomly.fromList(columns), null); + SQLite3Table selectedTable = selectedColumn.getColumn().getTable(); + InOperation inOperation = new InOperation(selectedColumn, new SQLite3Select(auxiliaryQuery)); + SQLite3ExpressionBag specificCondition = new SQLite3ExpressionBag(inOperation); + + originalQuery = this.genSelectExpression(selectedTable, specificCondition); + originalQueryString = SQLite3Visitor.asString(originalQuery); + originalResult = getQueryResult(originalQueryString, state); + // folded query + // can not use IN VALUES here, because there is no affinity for the right operand of IN when right operand + // is a list + try { + SQLite3Table t = this.createTemporaryTable(auxiliaryQuery, "intable"); + SQLite3TableReference equivalentTable = new SQLite3TableReference(t); + inOperation = new InOperation(selectedColumn, equivalentTable); + specificCondition.updateInnerExpr(inOperation); + foldedQueryString = SQLite3Visitor.asString(originalQuery); + foldedResult = getQueryResult(foldedQueryString, state); + } finally { + dropTemporaryTable("intable"); + } + } else { + // There is not `ANY` and `ALL` operator in SQLite3 + // Row Subquery + // original query + SQLite3Table temporaryTable = this.genTemporaryTable(auxiliaryQuery, SQLite3CODDTestOracle.TEMP_TABLE_NAME); + originalQuery = this.genSelectExpression(temporaryTable, null); + SQLite3TableAndColumnRef tableAndColumnRef = new SQLite3TableAndColumnRef(temporaryTable); + SQLite3WithClause withClause = new SQLite3WithClause(tableAndColumnRef, new SQLite3Select(auxiliaryQuery)); + originalQuery.setWithClause(withClause); + originalQueryString = SQLite3Visitor.asString(originalQuery); + originalResult = getQueryResult(originalQueryString, state); + // folded query + if (Randomly.getBoolean() && this.enableCommonTableExpression()) { + // there are too many false positives + // common table expression + // folded query: WITH table AS VALUES () + SQLite3Values values = new SQLite3Values(auxiliaryQueryResult, temporaryTable.getColumns()); + originalQuery.updateWithClauseRight(values); + foldedQueryString = SQLite3Visitor.asString(originalQuery); + foldedResult = getQueryResult(foldedQueryString, state); + } else if (Randomly.getBoolean() && this.enableDerivedTable()) { + // derived table + // folded query: SELECT FROM () AS table + originalQuery.setWithClause(null); + SQLite3TableReference tempTableRef = new SQLite3TableReference(temporaryTable); + SQLite3Alias alias = new SQLite3Alias(new SQLite3Select(auxiliaryQuery), tempTableRef); + originalQuery.replaceFromTable(SQLite3CODDTestOracle.TEMP_TABLE_NAME, alias); + foldedQueryString = SQLite3Visitor.asString(originalQuery); + foldedResult = getQueryResult(foldedQueryString, state); + } else if (this.enableInsert()) { + // there are too many false positives + // folded query: CREATE the table and INSERT INTO table subquery + try { + this.createTemporaryTable(auxiliaryQuery, SQLite3CODDTestOracle.TEMP_TABLE_NAME); + originalQuery.setWithClause(null); + foldedQueryString = SQLite3Visitor.asString(originalQuery); + foldedResult = getQueryResult(foldedQueryString, state); + } finally { + dropTemporaryTable(SQLite3CODDTestOracle.TEMP_TABLE_NAME); + } + } else { + throw new IgnoreMeException(); + } + } + if (foldedResult == null || originalResult == null) { + throw new IgnoreMeException(); + } + if (foldedQueryString.equals(originalQueryString)) { + throw new IgnoreMeException(); + } + if (!compareResult(foldedResult, originalResult)) { + reproducer = null; // TODO + state.getState().getLocalState() + .log(auxiliaryQueryString + ";\n" + foldedQueryString + ";\n" + originalQueryString + ";"); + throw new AssertionError( + auxiliaryQueryResult.toString() + " " + foldedResult.toString() + " " + originalResult.toString()); + } + } + + private SQLite3Select genSelectExpression(SQLite3Table tempTable, SQLite3Expression specificCondition) { + SQLite3Tables randomTables = s.getRandomTableNonEmptyTables(); + if (tempTable != null) { + randomTables.addTable(tempTable); + } + if (!useSubqueryAsFoldedExpr || useSubqueryAsFoldedExpr && useCorrelatedSubqueryAsFoldedExpr) { + for (SQLite3Table t : this.tablesFromOuterContext) { + randomTables.addTable(t); + } + if (this.joinsInExpr != null) { + for (Join j : this.joinsInExpr) { + SQLite3Table t = j.getTable(); + randomTables.removeTable(t); + } + } + } + + List columns = randomTables.getColumns(); + if ((!useSubqueryAsFoldedExpr || useSubqueryAsFoldedExpr && useCorrelatedSubqueryAsFoldedExpr) + && this.joinsInExpr != null) { + for (Join j : this.joinsInExpr) { + SQLite3Table t = j.getTable(); + columns.addAll(t.getColumns()); + } + } + gen = new SQLite3ExpressionGenerator(state).setColumns(columns); + List tables = randomTables.getTables(); + List joinStatements = new ArrayList<>(); + if (!useSubqueryAsFoldedExpr || useSubqueryAsFoldedExpr && useCorrelatedSubqueryAsFoldedExpr) { + if (this.joinsInExpr != null) { + joinStatements.addAll(this.joinsInExpr); + this.joinsInExpr = null; + } + } else if (Randomly.getBoolean()) { + joinStatements = genJoinExpression(gen, tables, + Randomly.getBooleanWithRatherLowProbability() ? specificCondition : null, false); + } + List tableRefs = SQLite3Common.getTableRefs(tables, s); + SQLite3Select select = new SQLite3Select(); + select.setFromList(tableRefs); + if (!joinStatements.isEmpty()) { + select.setJoinClauses(joinStatements); + } + + SQLite3Expression randomWhereCondition = gen.generateExpression(); + SQLite3Expression whereCondition = null; + if (specificCondition != null) { + BinaryOperator operator = BinaryOperator.getRandomOperator(); + whereCondition = new SQLite3Expression.Sqlite3BinaryOperation(randomWhereCondition, specificCondition, + operator); + } else { + whereCondition = randomWhereCondition; + } + select.setWhereClause(whereCondition); + + if (Randomly.getBoolean()) { + select.setOrderByClauses(genOrderBysExpression(gen, + Randomly.getBooleanWithRatherLowProbability() ? specificCondition : null)); + } + + if (Randomly.getBoolean()) { + List selectedColumns = Randomly.nonEmptySubset(columns); + List selectedAlias = new LinkedList<>(); + for (int i = 0; i < selectedColumns.size(); ++i) { + SQLite3ColumnName originalName = new SQLite3ColumnName(selectedColumns.get(i), null); + SQLite3ColumnName aliasName = new SQLite3ColumnName(SQLite3Column.createDummy("c" + i), null); + SQLite3Alias columnAlias = new SQLite3Alias(originalName, aliasName); + selectedAlias.add(columnAlias); + } + select.setFetchColumns(selectedAlias); + } else { + SQLite3ColumnName aggr = new SQLite3ColumnName(Randomly.fromList(columns), null); + SQLite3Provider.mustKnowResult = true; + SQLite3Expression originalName = new SQLite3Aggregate(Arrays.asList(aggr), + SQLite3Aggregate.SQLite3AggregateFunction.getRandom()); + SQLite3ColumnName aliasName = new SQLite3ColumnName(SQLite3Column.createDummy("c0"), null); + SQLite3Alias columnAlias = new SQLite3Alias(originalName, aliasName); + select.setFetchColumns(Arrays.asList(columnAlias)); + if (Randomly.getBooleanWithRatherLowProbability()) { + List groupByClause = genGroupByClause(columns, specificCondition); + select.setGroupByClause(groupByClause); + if (!groupByClause.isEmpty() && Randomly.getBooleanWithRatherLowProbability()) { + select.setHavingClause(genHavingClause(columns, specificCondition)); + } + } + } + return select; + } + + // For expression test + private SQLite3Select genSimpleSelect() { + SQLite3Tables randomTables = s.getRandomTableNonEmptyTables(); + List columns = randomTables.getColumns(); + + gen = new SQLite3ExpressionGenerator(state).setColumns(columns); + List tables = randomTables.getTables(); + tablesFromOuterContext = randomTables.getTables(); + + if (Randomly.getBooleanWithRatherLowProbability()) { + joinsInExpr = genJoinExpression(gen, tables, null, true); + } else { + joinsInExpr = new ArrayList<>(); + } + + List tableRefs = SQLite3Common.getTableRefs(tables, s); + SQLite3Select select = new SQLite3Select(); + select.setFromList(tableRefs); + if (joinsInExpr != null && !joinsInExpr.isEmpty()) { + select.setJoinClauses(joinsInExpr); + } + + SQLite3Expression whereCondition = gen.generateExpression(); + this.foldedExpr = whereCondition; + + List fetchColumns = new ArrayList<>(); + int columnIdx = 0; + for (SQLite3Column c : randomTables.getColumns()) { + SQLite3ColumnName cRef = new SQLite3ColumnName(c, null); + SQLite3ColumnName aliasName = new SQLite3ColumnName(SQLite3Column.createDummy("c" + columnIdx), null); + SQLite3Alias columnAlias = new SQLite3Alias(cRef, aliasName); + fetchColumns.add(columnAlias); + columnIdx++; + } + + // add the expression to fetch clause + SQLite3ColumnName aliasName = new SQLite3ColumnName(SQLite3Column.createDummy("c" + columnIdx), null); + SQLite3Alias columnAlias = new SQLite3Alias(whereCondition, aliasName); + fetchColumns.add(columnAlias); + + select.setFetchColumns(fetchColumns); + + Map> queryRes = null; + try { + queryRes = getQueryResult(SQLite3Visitor.asString(select), state); + } catch (SQLException e) { + if (errors.errorIsExpected(e.getMessage())) { + throw new IgnoreMeException(); + } else { + throw new AssertionError(e.getMessage()); + } + } + if (queryRes.get("c0").size() == 0) { + throw new IgnoreMeException(); + } + + // save the result first + selectResult.clear(); + selectResult.putAll(queryRes); + + // get the summary from results + List summary = queryRes.remove("c" + columnIdx); + + List tempColumnList = new ArrayList<>(); + + for (int i = 0; i < fetchColumns.size() - 1; ++i) { + // do not put the last fetch column to values + SQLite3Alias cAlias = (SQLite3Alias) fetchColumns.get(i); + SQLite3ColumnName cRef = (SQLite3ColumnName) cAlias.getOriginalExpression(); + SQLite3Column column = cRef.getColumn(); + String columnName = SQLite3Visitor.asString(cAlias.getAliasExpression()); + SQLite3Column newColumn = new SQLite3Column(columnName, column.getType(), false, false, null); + tempColumnList.add(newColumn); + } + List columnRef = new ArrayList<>(); + for (SQLite3Column c : randomTables.getColumns()) { + columnRef.add(new SQLite3ColumnName(c, null)); + } + if (tempColumnList.size() != queryRes.size()) { + throw new AssertionError(); + } + SQLite3Values values = new SQLite3Values(queryRes, tempColumnList); + this.constantResOfFoldedExpr = new SQLite3ResultMap(values, columnRef, summary, null); + + return select; + } + + private SQLite3Select genSelectWithCorrelatedSubquery() { + SQLite3Tables outerQueryRandomTables = s.getRandomTableNonEmptyTables(); + SQLite3Tables innerQueryRandomTables = s.getRandomTableNonEmptyTables(); + + List innerQueryFromTables = new ArrayList<>(); + for (SQLite3Table t : innerQueryRandomTables.getTables()) { + if (!outerQueryRandomTables.isContained(t)) { + innerQueryFromTables.add(new SQLite3TableReference(t)); + } + } + for (SQLite3Table t : outerQueryRandomTables.getTables()) { + if (innerQueryRandomTables.isContained(t)) { + innerQueryRandomTables.removeTable(t); + + List newColumns = new ArrayList<>(); + for (SQLite3Column c : t.getColumns()) { + SQLite3Column newColumn = new SQLite3Column(c.getName(), c.getType(), false, null, false); + newColumns.add(newColumn); + } + SQLite3Table newTable = new SQLite3Table(t.getName() + "a", newColumns, null, true, false, false, + false); + for (SQLite3Column c : newColumns) { + c.setTable(newTable); + } + innerQueryRandomTables.addTable(newTable); + + SQLite3Alias alias = new SQLite3Alias(new SQLite3TableReference(t), + new SQLite3TableReference(newTable)); + innerQueryFromTables.add(alias); + } + } + + List innerQueryColumns = new ArrayList<>(); + innerQueryColumns.addAll(innerQueryRandomTables.getColumns()); + innerQueryColumns.addAll(outerQueryRandomTables.getColumns()); + gen = new SQLite3ExpressionGenerator(state).setColumns(innerQueryColumns); + + SQLite3Select innerQuery = new SQLite3Select(); + innerQuery.setFromList(innerQueryFromTables); + + SQLite3Expression innerQueryWhereCondition = gen.generateExpression(); + innerQuery.setWhereClause(innerQueryWhereCondition); + + // use aggregate function in fetch column + SQLite3ColumnName innerQueryAggr = new SQLite3ColumnName(Randomly.fromList(innerQueryRandomTables.getColumns()), + null); + SQLite3Provider.mustKnowResult = true; + SQLite3Expression innerQueryAggrName = new SQLite3Aggregate(Arrays.asList(innerQueryAggr), + SQLite3Aggregate.SQLite3AggregateFunction.getRandom()); + innerQuery.setFetchColumns(Arrays.asList(innerQueryAggrName)); + if (Randomly.getBooleanWithRatherLowProbability()) { + List groupByClause = genGroupByClause(innerQueryColumns, null); + innerQuery.setGroupByClause(groupByClause); + if (!groupByClause.isEmpty() && Randomly.getBooleanWithRatherLowProbability()) { + innerQuery.setHavingClause(genHavingClause(innerQueryColumns, null)); + } + } + + this.foldedExpr = innerQuery; + + // outer query + SQLite3Select outerQuery = new SQLite3Select(); + outerQuery.setFromList(SQLite3Common.getTableRefs(outerQueryRandomTables.getTables(), s)); + tablesFromOuterContext = outerQueryRandomTables.getTables(); + + List outerQueryFetchColumns = new ArrayList<>(); + int columnIdx = 0; + for (SQLite3Column c : outerQueryRandomTables.getColumns()) { + SQLite3ColumnName cRef = new SQLite3ColumnName(c, null); + SQLite3ColumnName aliasName = new SQLite3ColumnName(SQLite3Column.createDummy("c" + columnIdx), null); + SQLite3Alias columnAlias = new SQLite3Alias(cRef, aliasName); + outerQueryFetchColumns.add(columnAlias); + columnIdx++; + } + + // add the expression to fetch clause + SQLite3ColumnName aliasName = new SQLite3ColumnName(SQLite3Column.createDummy("c" + columnIdx), null); + SQLite3Alias columnAlias = new SQLite3Alias(innerQuery, aliasName); + outerQueryFetchColumns.add(columnAlias); + + outerQuery.setFetchColumns(outerQueryFetchColumns); + + originalQueryString = SQLite3Visitor.asString(outerQuery); + + Map> queryRes = null; + try { + queryRes = getQueryResult(originalQueryString, state); + } catch (SQLException e) { + if (errors.errorIsExpected(e.getMessage())) { + throw new IgnoreMeException(); + } else { + throw new AssertionError(e.getMessage()); + } + } + if (queryRes.get("c0").size() == 0) { + throw new IgnoreMeException(); + } + + // save the result first + selectResult.clear(); + selectResult.putAll(queryRes); + + // get the summary from results + List summary = queryRes.remove("c" + columnIdx); + + List tempColumnList = new ArrayList<>(); + + for (int i = 0; i < outerQueryFetchColumns.size() - 1; ++i) { + // do not put the last fetch column to values + SQLite3Alias cAlias = (SQLite3Alias) outerQueryFetchColumns.get(i); + SQLite3ColumnName cRef = (SQLite3ColumnName) cAlias.getOriginalExpression(); + SQLite3Column column = cRef.getColumn(); + String columnName = SQLite3Visitor.asString(cAlias.getAliasExpression()); + SQLite3Column newColumn = new SQLite3Column(columnName, column.getType(), false, false, null); + tempColumnList.add(newColumn); + } + List columnRef = new ArrayList<>(); + for (SQLite3Column c : outerQueryRandomTables.getColumns()) { + columnRef.add(new SQLite3ColumnName(c, null)); + } + if (tempColumnList.size() != queryRes.size()) { + throw new AssertionError(); + } + SQLite3Values values = new SQLite3Values(queryRes, tempColumnList); + this.constantResOfFoldedExpr = new SQLite3ResultMap(values, columnRef, summary, null); + + return outerQuery; + } + + private List genJoinExpression(SQLite3ExpressionGenerator gen, List tables, + SQLite3Expression specificCondition, boolean joinForExperssion) { + List joinStatements = new ArrayList<>(); + if (!state.getDbmsSpecificOptions().testJoins) { + return joinStatements; + } + List options = new ArrayList<>(Arrays.asList(JoinType.values())); + if (Randomly.getBoolean() && tables.size() > 1) { + int nrJoinClauses = (int) Randomly.getNotCachedInteger(0, tables.size()); + // Natural join is incompatible with other joins + // because it needs unique column names + // while other joins will produce duplicate column names + if (nrJoinClauses > 1 || joinForExperssion) { + options.remove(JoinType.NATURAL); + } + for (int i = 0; i < nrJoinClauses; i++) { + SQLite3Expression randomOnCondition = gen.generateExpression(); + SQLite3Expression onCondition = null; + if (specificCondition != null && Randomly.getBooleanWithRatherLowProbability()) { + BinaryOperator operator = BinaryOperator.getRandomOperator(); + onCondition = new SQLite3Expression.Sqlite3BinaryOperation(randomOnCondition, specificCondition, + operator); + } else { + onCondition = randomOnCondition; + } + + SQLite3Table table = Randomly.fromList(tables); + tables.remove(table); + JoinType selectedOption = Randomly.fromList(options); + if (selectedOption == JoinType.NATURAL) { + // NATURAL joins do not have an ON clause + onCondition = null; + } + Join j = new SQLite3Expression.Join(table, onCondition, selectedOption); + joinStatements.add(j); + } + + } + return joinStatements; + } + + private List genOrderBysExpression(SQLite3ExpressionGenerator gen, + SQLite3Expression specificCondition) { + List expressions = new ArrayList<>(); + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + expressions.add( + genOrderingTerm(gen, Randomly.getBooleanWithRatherLowProbability() ? specificCondition : null)); + } + return expressions; + } + + private SQLite3Expression genOrderingTerm(SQLite3ExpressionGenerator gen, SQLite3Expression specificCondition) { + SQLite3Expression expr = gen.generateExpression(); + if (specificCondition != null && Randomly.getBooleanWithRatherLowProbability()) { + BinaryOperator operator = BinaryOperator.getRandomOperator(); + expr = new SQLite3Expression.Sqlite3BinaryOperation(expr, specificCondition, operator); + } + // COLLATE is potentially already generated + if (Randomly.getBoolean()) { + expr = new SQLite3OrderingTerm(expr, Ordering.getRandomValue()); + } + if (state.getDbmsSpecificOptions().testNullsFirstLast && Randomly.getBoolean()) { + expr = new SQLite3PostfixText(expr, Randomly.fromOptions(" NULLS FIRST", " NULLS LAST"), + null /* expr.getExpectedValue() */) { + @Override + public boolean omitBracketsWhenPrinting() { + return true; + } + }; + } + return expr; + } + + private List genGroupByClause(List columns, SQLite3Expression specificCondition) { + errors.add("GROUP BY term out of range"); + if (Randomly.getBoolean()) { + List collect = new ArrayList<>(); + for (int i = 0; i < Randomly.smallNumber(); i++) { + SQLite3Expression expr = new SQLite3ExpressionGenerator(state).setColumns(columns).generateExpression(); + if (specificCondition != null && Randomly.getBooleanWithRatherLowProbability()) { + BinaryOperator operator = BinaryOperator.getRandomOperator(); + expr = new SQLite3Expression.Sqlite3BinaryOperation(expr, specificCondition, operator); + } + collect.add(expr); + } + return collect; + } + return Collections.emptyList(); + } + + private SQLite3Expression genHavingClause(List columns, SQLite3Expression specificCondition) { + SQLite3Expression expr = new SQLite3ExpressionGenerator(state).setColumns(columns).generateExpression(); + if (specificCondition != null && Randomly.getBooleanWithRatherLowProbability()) { + BinaryOperator operator = BinaryOperator.getRandomOperator(); + expr = new SQLite3Expression.Sqlite3BinaryOperation(expr, specificCondition, operator); + } + return expr; + } + + private Map> getQueryResult(String queryString, SQLite3GlobalState state) + throws SQLException { + Map> result = new LinkedHashMap<>(); + if (options.logEachSelect()) { + logger.writeCurrentNoLineBreak(queryString); + } + Statement stmt = null; + try { + stmt = this.con.createStatement(); + stmt.setQueryTimeout(600); + ResultSet rs = null; + try { + rs = stmt.executeQuery(queryString); + ResultSetMetaData metaData = rs.getMetaData(); + Integer columnCount = metaData.getColumnCount(); + Map idxNameMap = new HashMap<>(); + for (int i = 1; i <= columnCount; i++) { + result.put("c" + (i - 1), new ArrayList<>()); + idxNameMap.put(i, "c" + (i - 1)); + } + + int resultRows = 0; + while (rs.next()) { + for (int i = 1; i <= columnCount; i++) { + try { + Object value = rs.getObject(i); + SQLite3Constant constant; + if (rs.wasNull()) { + constant = SQLite3Constant.createNullConstant(); + } else if (value instanceof Integer) { + constant = SQLite3Constant.createIntConstant(Long.valueOf((Integer) value)); + } else if (value instanceof Short) { + constant = SQLite3Constant.createIntConstant(Long.valueOf((Short) value)); + } else if (value instanceof Long) { + constant = SQLite3Constant.createIntConstant((Long) value); + } else if (value instanceof Double) { + constant = SQLite3Constant.createRealConstant((double) value); + } else if (value instanceof Float) { + constant = SQLite3Constant.createRealConstant(((Float) value).doubleValue()); + } else if (value instanceof BigDecimal) { + constant = SQLite3Constant.createRealConstant(((BigDecimal) value).doubleValue()); + } else if (value instanceof byte[]) { + constant = SQLite3Constant.createBinaryConstant((byte[]) value); + } else if (value instanceof Boolean) { + constant = SQLite3Constant.createBoolean((boolean) value); + } else if (value instanceof String) { + constant = SQLite3Constant.createTextConstant((String) value); + } else if (value == null) { + constant = SQLite3Constant.createNullConstant(); + } else { + throw new IgnoreMeException(); + } + List v = result.get(idxNameMap.get(i)); + v.add(constant); + } catch (SQLException e) { + System.out.println(e.getMessage()); + throw new IgnoreMeException(); + } + } + ++resultRows; + if (resultRows > 100) { + throw new IgnoreMeException(); + } + } + Main.nrSuccessfulActions.addAndGet(1); + rs.close(); + } catch (SQLException e) { + Main.nrUnsuccessfulActions.addAndGet(1); + if (errors.errorIsExpected(e.getMessage())) { + throw new IgnoreMeException(); + } else { + state.getState().getLocalState().log(queryString); + throw new AssertionError(e.getMessage()); + } + } finally { + if (rs != null) { + rs.close(); + } + } + } finally { + if (stmt != null) { + stmt.close(); + } + } + return result; + } + + private SQLite3Table genTemporaryTable(SQLite3Select select, String tableName) { + List fetchColumns = select.getFetchColumns(); + int columnNumber = fetchColumns.size(); + Map idxTypeMap = getColumnTypeFromSelect(select); + + List databaseColumns = new ArrayList<>(); + for (int i = 0; i < columnNumber; ++i) { + String columnName = "c" + i; + SQLite3Column column = new SQLite3Column(columnName, idxTypeMap.get(i), false, false, null); + databaseColumns.add(column); + } + SQLite3Table table = new SQLite3Table(tableName, databaseColumns, null, false, false, false, false); + for (SQLite3Column c : databaseColumns) { + c.setTable(table); + } + + return table; + } + + private SQLite3Table createTemporaryTable(SQLite3Select select, String tableName) throws SQLException { + String selectString = SQLite3Visitor.asString(select); + Map idxTypeMap = getColumnTypeFromSelect(select); + + Integer columnNumber = idxTypeMap.size(); + StringBuilder sb = new StringBuilder(); + sb.append("CREATE TABLE " + tableName + " ("); + for (int i = 0; i < columnNumber; ++i) { + String columnTypeName = ""; + if (idxTypeMap.get(i) != null) { + switch (idxTypeMap.get(i)) { + case INT: + case TEXT: + case REAL: + columnTypeName = idxTypeMap.get(i).name(); + break; + case BINARY: + columnTypeName = ""; + break; + default: + columnTypeName = ""; + } + } + sb.append("c" + i + " " + columnTypeName); + if (i < columnNumber - 1) { + sb.append(", "); + } + } + sb.append(");"); + String crateTableString = sb.toString(); + if (options.logEachSelect()) { + logger.writeCurrent(crateTableString); + } + Statement stmt = null; + try { + stmt = this.con.createStatement(); + try { + stmt.execute(crateTableString); + Main.nrSuccessfulActions.addAndGet(1); + } catch (SQLException e) { + Main.nrUnsuccessfulActions.addAndGet(1); + throw new IgnoreMeException(); + } + } finally { + if (stmt != null) { + stmt.close(); + } + } + + StringBuilder sb2 = new StringBuilder(); + sb2.append("INSERT INTO " + tableName + " " + selectString); + String insertValueString = sb2.toString(); + if (options.logEachSelect()) { + logger.writeCurrent(insertValueString); + } + stmt = null; + try { + stmt = this.con.createStatement(); + try { + Main.nrSuccessfulActions.addAndGet(1); + stmt.execute(insertValueString); + } catch (SQLException e) { + Main.nrUnsuccessfulActions.addAndGet(1); + throw new IgnoreMeException(); + } + } finally { + if (stmt != null) { + stmt.close(); + } + } + + List databaseColumns = new ArrayList<>(); + for (int i = 0; i < columnNumber; ++i) { + String columnName = "c" + i; + SQLite3Column column = new SQLite3Column(columnName, idxTypeMap.get(i), false, false, null); + databaseColumns.add(column); + } + SQLite3Table table = new SQLite3Table(tableName, databaseColumns, null, false, false, false, false); + for (SQLite3Column c : databaseColumns) { + c.setTable(table); + } + + return table; + } + + private void dropTemporaryTable(String tableName) throws SQLException { + String dropString = "DROP TABLE " + tableName + ";"; + if (options.logEachSelect()) { + logger.writeCurrent(dropString); + } + Statement stmt = null; + try { + stmt = this.con.createStatement(); + try { + stmt.execute(dropString); + Main.nrSuccessfulActions.addAndGet(1); + } catch (SQLException e) { + Main.nrUnsuccessfulActions.addAndGet(1); + throw new IgnoreMeException(); + } + } finally { + if (stmt != null) { + stmt.close(); + } + } + } + + private boolean compareResult(Map> r1, Map> r2) { + if (r1.size() != r2.size()) { + return false; + } + for (Map.Entry> entry : r1.entrySet()) { + String currentKey = entry.getKey(); + if (!r2.containsKey(currentKey)) { + return false; + } + List v1 = entry.getValue(); + List v2 = r2.get(currentKey); + if (v1.size() != v2.size()) { + return false; + } + List v1Value = new ArrayList<>(v1.stream().map(c -> c.toString()).collect(Collectors.toList())); + List v2Value = new ArrayList<>(v2.stream().map(c -> c.toString()).collect(Collectors.toList())); + Collections.sort(v1Value); + Collections.sort(v2Value); + if (!v1Value.equals(v2Value)) { + return false; + } + } + return true; + } + + private Map getColumnTypeFromSelect(SQLite3Select select) { + List fetchColumns = select.getFetchColumns(); + List newFetchColumns = new ArrayList<>(); + for (SQLite3Expression column : fetchColumns) { + newFetchColumns.add(column); + SQLite3Alias columnAlias = (SQLite3Alias) column; + SQLite3Expression typeofColumn = new SQLite3Typeof(columnAlias.getOriginalExpression()); + newFetchColumns.add(typeofColumn); + } + SQLite3Select newSelect = new SQLite3Select(select); + newSelect.setFetchColumns(newFetchColumns); + Map> typeResult = null; + try { + typeResult = getQueryResult(SQLite3Visitor.asString(newSelect), state); + } catch (SQLException e) { + if (errors.errorIsExpected(e.getMessage())) { + throw new IgnoreMeException(); + } else { + throw new AssertionError(e.getMessage()); + } + } + + if (typeResult == null) { + throw new IgnoreMeException(); + } + Map idxTypeMap = new HashMap<>(); + for (int i = 0; i * 2 < typeResult.size(); ++i) { + String columnName = "c" + (i * 2 + 1); + SQLite3Expression t = typeResult.get(columnName).get(0); + SQLite3TextConstant tString = (SQLite3TextConstant) t; + String typeName = tString.asString(); + SQLite3DataType cType = SQLite3DataType.getTypeFromName(typeName); + idxTypeMap.put(i, cType); + } + + return idxTypeMap; + } + + public boolean useSubquery() { + if (this.state.getDbmsSpecificOptions().coddTestModel.isRandom()) { + return Randomly.getBoolean(); + } else if (this.state.getDbmsSpecificOptions().coddTestModel.isExpression()) { + return false; + } else if (this.state.getDbmsSpecificOptions().coddTestModel.isSubquery()) { + return true; + } else { + System.out.printf("Wrong option of --coddtest-model, should be one of: RANDOM, EXPRESSION, SUBQUERY"); + System.exit(1); + return false; + } + } + + public boolean useCorrelatedSubquery() { + return Randomly.getBoolean(); + } + + public boolean enableCommonTableExpression() { + return false; + } + + public boolean enableDerivedTable() { + return true; + } + + public boolean enableInsert() { + return false; + } + + public boolean enableInOperator() { + return false; + } + + @Override + public String getLastQueryString() { + return originalQueryString; + } + + @Override + public Reproducer getLastReproducer() { + return reproducer; + } +} diff --git a/src/sqlancer/sqlite3/oracle/SQLite3Fuzzer.java b/src/sqlancer/sqlite3/oracle/SQLite3Fuzzer.java index 8e868dfd0..e98e5f032 100644 --- a/src/sqlancer/sqlite3/oracle/SQLite3Fuzzer.java +++ b/src/sqlancer/sqlite3/oracle/SQLite3Fuzzer.java @@ -7,7 +7,7 @@ import sqlancer.sqlite3.SQLite3Visitor; // tries to trigger a crash -public class SQLite3Fuzzer implements TestOracle { +public class SQLite3Fuzzer implements TestOracle { private final SQLite3GlobalState globalState; @@ -20,10 +20,8 @@ public void check() throws Exception { String s = SQLite3Visitor .asString(SQLite3RandomQuerySynthesizer.generate(globalState, Randomly.smallNumber() + 1)) + ";"; try { - if (globalState.getDbmsSpecificOptions().executeQuery) { - globalState.executeStatement(new SQLQueryAdapter(s)); - globalState.getManager().incrementSelectQueryCount(); - } + globalState.executeStatement(new SQLQueryAdapter(s)); + globalState.getManager().incrementSelectQueryCount(); } catch (Error e) { } diff --git a/src/sqlancer/sqlite3/oracle/SQLite3NoRECOracle.java b/src/sqlancer/sqlite3/oracle/SQLite3NoRECOracle.java deleted file mode 100644 index db9f55531..000000000 --- a/src/sqlancer/sqlite3/oracle/SQLite3NoRECOracle.java +++ /dev/null @@ -1,158 +0,0 @@ -package sqlancer.sqlite3.oracle; - -import java.sql.SQLException; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; - -import sqlancer.IgnoreMeException; -import sqlancer.Randomly; -import sqlancer.common.oracle.NoRECBase; -import sqlancer.common.oracle.TestOracle; -import sqlancer.common.query.SQLQueryAdapter; -import sqlancer.common.query.SQLancerResultSet; -import sqlancer.sqlite3.SQLite3Errors; -import sqlancer.sqlite3.SQLite3GlobalState; -import sqlancer.sqlite3.SQLite3Visitor; -import sqlancer.sqlite3.ast.SQLite3Aggregate; -import sqlancer.sqlite3.ast.SQLite3Expression; -import sqlancer.sqlite3.ast.SQLite3Expression.Join; -import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3ColumnName; -import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3PostfixText; -import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3PostfixUnaryOperation; -import sqlancer.sqlite3.ast.SQLite3Expression.SQLite3PostfixUnaryOperation.PostfixUnaryOperator; -import sqlancer.sqlite3.ast.SQLite3Select; -import sqlancer.sqlite3.gen.SQLite3Common; -import sqlancer.sqlite3.gen.SQLite3ExpressionGenerator; -import sqlancer.sqlite3.schema.SQLite3Schema; -import sqlancer.sqlite3.schema.SQLite3Schema.SQLite3Column; -import sqlancer.sqlite3.schema.SQLite3Schema.SQLite3Table; -import sqlancer.sqlite3.schema.SQLite3Schema.SQLite3Tables; - -public class SQLite3NoRECOracle extends NoRECBase implements TestOracle { - - private static final int NO_VALID_RESULT = -1; - private final SQLite3Schema s; - private SQLite3ExpressionGenerator gen; - - public SQLite3NoRECOracle(SQLite3GlobalState globalState) { - super(globalState); - this.s = globalState.getSchema(); - SQLite3Errors.addExpectedExpressionErrors(errors); - SQLite3Errors.addMatchQueryErrors(errors); - SQLite3Errors.addQueryErrors(errors); - errors.add("misuse of aggregate"); - errors.add("misuse of window function"); - errors.add("second argument to nth_value must be a positive integer"); - errors.add("no such table"); - errors.add("no query solution"); - errors.add("unable to use function MATCH in the requested context"); - } - - @Override - public void check() throws SQLException { - SQLite3Tables randomTables = s.getRandomTableNonEmptyTables(); - List columns = randomTables.getColumns(); - gen = new SQLite3ExpressionGenerator(state).setColumns(columns); - SQLite3Expression randomWhereCondition = gen.generateExpression(); - List tables = randomTables.getTables(); - List joinStatements = gen.getRandomJoinClauses(tables); - List tableRefs = SQLite3Common.getTableRefs(tables, s); - SQLite3Select select = new SQLite3Select(); - select.setFromTables(tableRefs); - select.setJoinClauses(joinStatements); - - int optimizedCount = getOptimizedQuery(select, randomWhereCondition); - int unoptimizedCount = getUnoptimizedQuery(select, randomWhereCondition); - if (optimizedCount == NO_VALID_RESULT || unoptimizedCount == NO_VALID_RESULT) { - throw new IgnoreMeException(); - } - if (optimizedCount != unoptimizedCount) { - state.getState().getLocalState().log(optimizedQueryString + ";\n" + unoptimizedQueryString + ";"); - throw new AssertionError(optimizedCount + " " + unoptimizedCount); - } - - } - - private int getUnoptimizedQuery(SQLite3Select select, SQLite3Expression randomWhereCondition) throws SQLException { - SQLite3PostfixUnaryOperation isTrue = new SQLite3PostfixUnaryOperation(PostfixUnaryOperator.IS_TRUE, - randomWhereCondition); - SQLite3PostfixText asText = new SQLite3PostfixText(isTrue, " as count", null); - select.setFetchColumns(Arrays.asList(asText)); - select.setWhereClause(null); - unoptimizedQueryString = "SELECT SUM(count) FROM (" + SQLite3Visitor.asString(select) + ")"; - if (options.logEachSelect()) { - logger.writeCurrent(unoptimizedQueryString); - } - SQLQueryAdapter q = new SQLQueryAdapter(unoptimizedQueryString, errors); - return extractCounts(q); - } - - private int getOptimizedQuery(SQLite3Select select, SQLite3Expression randomWhereCondition) throws SQLException { - boolean useAggregate = Randomly.getBoolean(); - if (Randomly.getBoolean()) { - select.setOrderByExpressions(gen.generateOrderBys()); - } - if (useAggregate) { - select.setFetchColumns(Arrays.asList(new SQLite3Aggregate(Collections.emptyList(), - SQLite3Aggregate.SQLite3AggregateFunction.COUNT_ALL))); - } else { - SQLite3ColumnName aggr = new SQLite3ColumnName(SQLite3Column.createDummy("*"), null); - select.setFetchColumns(Arrays.asList(aggr)); - } - select.setWhereClause(randomWhereCondition); - optimizedQueryString = SQLite3Visitor.asString(select); - if (options.logEachSelect()) { - logger.writeCurrent(optimizedQueryString); - } - SQLQueryAdapter q = new SQLQueryAdapter(optimizedQueryString, errors); - return useAggregate ? extractCounts(q) : countRows(q); - } - - private int countRows(SQLQueryAdapter q) { - int count = 0; - try (SQLancerResultSet rs = q.executeAndGet(state)) { - if (rs == null) { - return NO_VALID_RESULT; - } else { - try { - while (rs.next()) { - count++; - } - } catch (SQLException e) { - count = NO_VALID_RESULT; - } - } - } catch (Exception e) { - if (e instanceof IgnoreMeException) { - throw (IgnoreMeException) e; - } - throw new AssertionError(unoptimizedQueryString, e); - } - return count; - } - - private int extractCounts(SQLQueryAdapter q) { - int count = 0; - try (SQLancerResultSet rs = q.executeAndGet(state)) { - if (rs == null) { - return NO_VALID_RESULT; - } else { - try { - while (rs.next()) { - count += rs.getInt(1); - } - } catch (SQLException e) { - count = NO_VALID_RESULT; - } - } - } catch (Exception e) { - if (e instanceof IgnoreMeException) { - throw (IgnoreMeException) e; - } - throw new AssertionError(unoptimizedQueryString, e); - } - return count; - } - -} diff --git a/src/sqlancer/sqlite3/oracle/SQLite3PivotedQuerySynthesisOracle.java b/src/sqlancer/sqlite3/oracle/SQLite3PivotedQuerySynthesisOracle.java index c19a1059d..98a30dc23 100644 --- a/src/sqlancer/sqlite3/oracle/SQLite3PivotedQuerySynthesisOracle.java +++ b/src/sqlancer/sqlite3/oracle/SQLite3PivotedQuerySynthesisOracle.java @@ -75,7 +75,7 @@ public SQLite3Select getQuery() throws SQLException { .filter(c -> !SQLite3Schema.ROWID_STRINGS.contains(c.getName())).collect(Collectors.toList()); List joinStatements = getJoinStatements(globalState, tables, columnsWithoutRowid); selectStatement.setJoinClauses(joinStatements); - selectStatement.setFromTables(SQLite3Common.getTableRefs(tables, globalState.getSchema())); + selectStatement.setFromList(SQLite3Common.getTableRefs(tables, globalState.getSchema())); fetchColumns = Randomly.nonEmptySubset(columnsWithoutRowid); List allTables = new ArrayList<>(); @@ -99,7 +99,7 @@ public SQLite3Select getQuery() throws SQLException { } /* PQS does not check for ordering, so we can generate any ORDER BY clause */ List orderBy = new SQLite3ExpressionGenerator(globalState).generateOrderBys(); - selectStatement.setOrderByExpressions(orderBy); + selectStatement.setOrderByClauses(orderBy); if (!groupByClause.isEmpty() && Randomly.getBoolean()) { selectStatement.setHavingClause(generateRectifiedExpression(columns, pivotRow, true)); } diff --git a/src/sqlancer/sqlite3/oracle/SQLite3RandomQuerySynthesizer.java b/src/sqlancer/sqlite3/oracle/SQLite3RandomQuerySynthesizer.java index 4447ea2fb..f8f53ae3b 100644 --- a/src/sqlancer/sqlite3/oracle/SQLite3RandomQuerySynthesizer.java +++ b/src/sqlancer/sqlite3/oracle/SQLite3RandomQuerySynthesizer.java @@ -37,10 +37,11 @@ public static SQLite3Expression generate(SQLite3GlobalState globalState, int siz SQLite3Tables targetTables = s.getRandomTableNonEmptyTables(); List expressions = new ArrayList<>(); SQLite3ExpressionGenerator gen = new SQLite3ExpressionGenerator(globalState) - .setColumns(s.getTables().getColumns()); - SQLite3ExpressionGenerator whereClauseGen = new SQLite3ExpressionGenerator(globalState); + .setColumns(targetTables.getColumns()); + SQLite3ExpressionGenerator whereClauseGen = new SQLite3ExpressionGenerator(globalState) + .setColumns(targetTables.getColumns()); SQLite3ExpressionGenerator aggregateGen = new SQLite3ExpressionGenerator(globalState) - .setColumns(s.getTables().getColumns()).allowAggregateFunctions(); + .setColumns(targetTables.getColumns()).allowAggregateFunctions(); // SELECT SQLite3Select select = new SQLite3Select(); @@ -102,7 +103,8 @@ public static SQLite3Expression generate(SQLite3GlobalState globalState, int siz select.setFromList(SQLite3Common.getTableRefs(tables, s)); // TODO: no values are referenced from this sub query yet // if (Randomly.getBooleanWithSmallProbability()) { - // select.getFromList().add(SQLite3RandomQuerySynthesizer.generate(globalState, Randomly.smallNumber() + 1)); + // select.getFromList().add(SQLite3RandomQuerySynthesizer.generate(globalState, + // Randomly.smallNumber() + 1)); // } // WHERE @@ -121,7 +123,7 @@ public static SQLite3Expression generate(SQLite3GlobalState globalState, int siz boolean orderBy = Randomly.getBooleanWithRatherLowProbability(); if (orderBy) { // ORDER BY - select.setOrderByExpressions(gen.generateOrderBys()); + select.setOrderByClauses(gen.generateOrderBys()); } if (Randomly.getBooleanWithRatherLowProbability()) { // LIMIT diff --git a/src/sqlancer/sqlite3/oracle/tlp/SQLite3TLPAggregateOracle.java b/src/sqlancer/sqlite3/oracle/tlp/SQLite3TLPAggregateOracle.java index c0243425f..4d1dd8529 100644 --- a/src/sqlancer/sqlite3/oracle/tlp/SQLite3TLPAggregateOracle.java +++ b/src/sqlancer/sqlite3/oracle/tlp/SQLite3TLPAggregateOracle.java @@ -28,11 +28,12 @@ import sqlancer.sqlite3.schema.SQLite3Schema; import sqlancer.sqlite3.schema.SQLite3Schema.SQLite3Tables; -public class SQLite3TLPAggregateOracle implements TestOracle { +public class SQLite3TLPAggregateOracle implements TestOracle { private final SQLite3GlobalState state; private final ExpectedErrors errors = new ExpectedErrors(); private SQLite3ExpressionGenerator gen; + private String generatedQueryString; public SQLite3TLPAggregateOracle(SQLite3GlobalState state) { this.state = state; @@ -53,10 +54,10 @@ public void check() throws SQLException { List from = SQLite3Common.getTableRefs(targetTables.getTables(), s); select.setFromList(from); if (Randomly.getBoolean()) { - select.setOrderByExpressions(gen.generateOrderBys()); + select.setOrderByClauses(gen.generateOrderBys()); } String originalQuery = SQLite3Visitor.asString(select); - + generatedQueryString = originalQuery; SQLite3Expression whereClause = gen.generateExpression(); SQLite3UnaryOperation negatedClause = new SQLite3UnaryOperation(UnaryOperator.NOT, whereClause); SQLite3PostfixUnaryOperation notNullClause = new SQLite3PostfixUnaryOperation(PostfixUnaryOperator.ISNULL, @@ -65,7 +66,9 @@ public void check() throws SQLException { SQLite3Select leftSelect = getSelect(aggregate, from, whereClause); SQLite3Select middleSelect = getSelect(aggregate, from, negatedClause); SQLite3Select rightSelect = getSelect(aggregate, from, notNullClause); - String metamorphicText = "SELECT " + aggregate.getFunc().toString() + "(aggr) FROM ("; + String aggreateMethod = aggregate.getFunc() == SQLite3AggregateFunction.COUNT_ALL + ? SQLite3AggregateFunction.COUNT.toString() : aggregate.getFunc().toString(); + String metamorphicText = "SELECT " + aggreateMethod + "(aggr) FROM ("; metamorphicText += SQLite3Visitor.asString(leftSelect) + " UNION ALL " + SQLite3Visitor.asString(middleSelect) + " UNION ALL " + SQLite3Visitor.asString(rightSelect); metamorphicText += ")"; @@ -117,9 +120,14 @@ private SQLite3Select getSelect(SQLite3Aggregate aggregate, List - implements TestOracle { + implements TestOracle { SQLite3Schema s; SQLite3Tables targetTables; @@ -48,7 +48,7 @@ public void check() throws SQLException { List joinStatements = gen.getRandomJoinClauses(tables); List tableRefs = SQLite3Common.getTableRefs(tables, s); select.setJoinClauses(joinStatements.stream().collect(Collectors.toList())); - select.setFromTables(tableRefs); + select.setFromList(tableRefs); select.setWhereClause(null); } diff --git a/src/sqlancer/sqlite3/oracle/tlp/SQLite3TLPDistinctOracle.java b/src/sqlancer/sqlite3/oracle/tlp/SQLite3TLPDistinctOracle.java index 227db3b6a..80c55c3e2 100644 --- a/src/sqlancer/sqlite3/oracle/tlp/SQLite3TLPDistinctOracle.java +++ b/src/sqlancer/sqlite3/oracle/tlp/SQLite3TLPDistinctOracle.java @@ -11,6 +11,8 @@ public class SQLite3TLPDistinctOracle extends SQLite3TLPBase { + private String generatedQueryString; + public SQLite3TLPDistinctOracle(SQLite3GlobalState state) { super(state); } @@ -21,7 +23,7 @@ public void check() throws SQLException { select.setSelectType(SelectType.DISTINCT); select.setWhereClause(null); String originalQueryString = SQLite3Visitor.asString(select); - + generatedQueryString = originalQueryString; List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); select.setWhereClause(predicate); @@ -37,4 +39,9 @@ public void check() throws SQLException { state); } + @Override + public String getLastQueryString() { + return generatedQueryString; + } + } diff --git a/src/sqlancer/sqlite3/oracle/tlp/SQLite3TLPGroupByOracle.java b/src/sqlancer/sqlite3/oracle/tlp/SQLite3TLPGroupByOracle.java index 91977d973..3b39ef4c8 100644 --- a/src/sqlancer/sqlite3/oracle/tlp/SQLite3TLPGroupByOracle.java +++ b/src/sqlancer/sqlite3/oracle/tlp/SQLite3TLPGroupByOracle.java @@ -14,6 +14,8 @@ public class SQLite3TLPGroupByOracle extends SQLite3TLPBase { + private String generatedQueryString; + public SQLite3TLPGroupByOracle(SQLite3GlobalState state) { super(state); } @@ -24,7 +26,7 @@ public void check() throws SQLException { select.setGroupByClause(select.getFetchColumns()); select.setWhereClause(null); String originalQueryString = SQLite3Visitor.asString(select); - + generatedQueryString = originalQueryString; List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); select.setWhereClause(predicate); @@ -46,4 +48,9 @@ List generateFetchColumns() { .collect(Collectors.toList()); } + @Override + public String getLastQueryString() { + return generatedQueryString; + } + } diff --git a/src/sqlancer/sqlite3/oracle/tlp/SQLite3TLPHavingOracle.java b/src/sqlancer/sqlite3/oracle/tlp/SQLite3TLPHavingOracle.java index 2524d979b..248d4db5f 100644 --- a/src/sqlancer/sqlite3/oracle/tlp/SQLite3TLPHavingOracle.java +++ b/src/sqlancer/sqlite3/oracle/tlp/SQLite3TLPHavingOracle.java @@ -29,10 +29,11 @@ import sqlancer.sqlite3.schema.SQLite3Schema.SQLite3Table; import sqlancer.sqlite3.schema.SQLite3Schema.SQLite3Tables; -public class SQLite3TLPHavingOracle implements TestOracle { +public class SQLite3TLPHavingOracle implements TestOracle { private final SQLite3GlobalState state; private final ExpectedErrors errors = new ExpectedErrors(); + private String generatedQueryString; public SQLite3TLPHavingOracle(SQLite3GlobalState state) { this.state = state; @@ -56,12 +57,12 @@ public void check() throws SQLException { List from = SQLite3Common.getTableRefs(tables, state.getSchema()); select.setJoinClauses(joinStatements); select.setSelectType(SelectType.ALL); - select.setFromTables(from); + select.setFromList(from); // TODO order by? select.setGroupByClause(groupByColumns); select.setHavingClause(null); String originalQueryString = SQLite3Visitor.asString(select); - + generatedQueryString = originalQueryString; List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); SQLite3Expression predicate = gen.getHavingClause(); @@ -84,4 +85,9 @@ public void check() throws SQLException { throw new AssertionError(originalQueryString + ";\n" + combinedString + ";"); } } + + @Override + public String getLastQueryString() { + return generatedQueryString; + } } diff --git a/src/sqlancer/sqlite3/oracle/tlp/SQLite3TLPWhereOracle.java b/src/sqlancer/sqlite3/oracle/tlp/SQLite3TLPWhereOracle.java deleted file mode 100644 index 95026de86..000000000 --- a/src/sqlancer/sqlite3/oracle/tlp/SQLite3TLPWhereOracle.java +++ /dev/null @@ -1,43 +0,0 @@ -package sqlancer.sqlite3.oracle.tlp; - -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.List; - -import sqlancer.ComparatorHelper; -import sqlancer.Randomly; -import sqlancer.sqlite3.SQLite3GlobalState; -import sqlancer.sqlite3.SQLite3Visitor; - -public class SQLite3TLPWhereOracle extends SQLite3TLPBase { - - public SQLite3TLPWhereOracle(SQLite3GlobalState state) { - super(state); - } - - @Override - public void check() throws SQLException { - super.check(); - select.setWhereClause(null); - String originalQueryString = SQLite3Visitor.asString(select); - - List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); - - boolean orderBy = Randomly.getBooleanWithSmallProbability(); - if (orderBy) { - select.setOrderByExpressions(gen.generateOrderBys()); - } - select.setWhereClause(predicate); - String firstQueryString = SQLite3Visitor.asString(select); - select.setWhereClause(negatedPredicate); - String secondQueryString = SQLite3Visitor.asString(select); - select.setWhereClause(isNullPredicate); - String thirdQueryString = SQLite3Visitor.asString(select); - List combinedString = new ArrayList<>(); - List secondResultSet = ComparatorHelper.getCombinedResultSet(firstQueryString, secondQueryString, - thirdQueryString, combinedString, !orderBy, state, errors); - ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, - state); - } - -} diff --git a/src/sqlancer/sqlite3/schema/SQLite3DataType.java b/src/sqlancer/sqlite3/schema/SQLite3DataType.java index 87d91f451..8a343fed2 100644 --- a/src/sqlancer/sqlite3/schema/SQLite3DataType.java +++ b/src/sqlancer/sqlite3/schema/SQLite3DataType.java @@ -1,6 +1,23 @@ package sqlancer.sqlite3.schema; +import sqlancer.IgnoreMeException; + public enum SQLite3DataType { NULL, INT, TEXT, REAL, NONE, BINARY; + public static SQLite3DataType getTypeFromName(String name) { + if (name.equals("integer")) { + return INT; + } else if (name.equals("real")) { + return REAL; + } else if (name.equals("text")) { + return TEXT; + } else if (name.equals("blob")) { + return NONE; + } else if (name.equals("null")) { + return NULL; + } else { + throw new IgnoreMeException(); + } + } } diff --git a/src/sqlancer/sqlite3/schema/SQLite3Schema.java b/src/sqlancer/sqlite3/schema/SQLite3Schema.java index b4bf11002..fc97929d3 100644 --- a/src/sqlancer/sqlite3/schema/SQLite3Schema.java +++ b/src/sqlancer/sqlite3/schema/SQLite3Schema.java @@ -78,6 +78,7 @@ public SQLite3Column(String rowId, SQLite3DataType columnType, boolean isInteger this.generated = generated; } + @Override public boolean isPrimaryKey() { return isPrimaryKey; } @@ -462,4 +463,32 @@ public List getDatabaseTablesWithoutViewsWithoutVirtualTables() { return getDatabaseTables().stream().filter(t -> !t.isView() && !t.isVirtual).collect(Collectors.toList()); } + public String getFreeVirtualTableName() { + int i = 0; + if (Randomly.getBooleanWithRatherLowProbability()) { + i = (int) Randomly.getNotCachedInteger(0, 100); + } + do { + String tableName = String.format("vt%d", i++); + if (getDatabaseTables().stream().noneMatch(t -> t.getName().equalsIgnoreCase(tableName))) { + return tableName; + } + } while (true); + + } + + public String getFreeRtreeTableName() { + int i = 0; + if (Randomly.getBooleanWithRatherLowProbability()) { + i = (int) Randomly.getNotCachedInteger(0, 100); + } + do { + String tableName = String.format("rt%d", i++); + if (getDatabaseTables().stream().noneMatch(t -> t.getName().equalsIgnoreCase(tableName))) { + return tableName; + } + } while (true); + + } + } diff --git a/src/sqlancer/tidb/TiDBBugs.java b/src/sqlancer/tidb/TiDBBugs.java index a4f5cc2dd..3849a3c8c 100644 --- a/src/sqlancer/tidb/TiDBBugs.java +++ b/src/sqlancer/tidb/TiDBBugs.java @@ -2,49 +2,6 @@ // do not make the fields final to avoid warnings public final class TiDBBugs { - - // https://github.com/pingcap/tidb/issues/15987 - public static boolean bug15987; - - // // https://github.com/pingcap/tidb/issues/15988 - public static boolean bug15988; - - // https://github.com/pingcap/tidb/issues/16028 - public static boolean bug16028; - - // https://github.com/pingcap/tidb/issues/16020 - public static boolean bug16020; - - // https://github.com/pingcap/tidb/issues/15990 - public static boolean bug15990; - - // https://github.com/pingcap/tidb/issues/15844 - public static boolean bug15844; - - // https://github.com/tidb-challenge-program/bug-hunting-issue/issues/10 - public static boolean bug10; - - // https://github.com/tidb-challenge-program/bug-hunting-issue/issues/14 - public static boolean bug14; - - // https://github.com/tidb-challenge-program/bug-hunting-issue/issues/15 - public static boolean bug15; - - // https://github.com/tidb-challenge-program/bug-hunting-issue/issues/16 - public static boolean bug16; - - // https://github.com/tidb-challenge-program/bug-hunting-issue/issues/19 - public static boolean bug19; - - // https://github.com/tidb-challenge-program/bug-hunting-issue/issues/48 - public static boolean bug48; - - // https://github.com/tidb-challenge-program/bug-hunting-issue/issues/50 - public static boolean bug50; - - // https://github.com/citusdata/citus/issues/4079 - public static boolean bug4079; - // https://github.com/pingcap/tidb/issues/35677 public static boolean bug35677 = true; @@ -57,6 +14,30 @@ public final class TiDBBugs { // https://github.com/pingcap/tidb/issues/38295 public static boolean bug38295 = true; + // https://github.com/pingcap/tidb/issues/38319 + public static boolean bug38319 = true; + + // https://github.com/pingcap/tidb/issues/44747 + public static boolean bug44747 = true; + + // https://github.com/pingcap/tidb/issues/46556 + public static boolean bug46556 = true; + + // https://github.com/pingcap/tidb/issues/46591 + public static boolean bug46591 = true; + + // https://github.com/pingcap/tidb/issues/46598 + public static boolean bug46598 = true; + + // https://github.com/pingcap/tidb/issues/47346 + public static boolean bug47346 = true; + + // https://github.com/pingcap/tidb/issues/47348 + public static boolean bug47348 = true; + + // https://github.com/pingcap/tidb/issues/51525 + public static boolean bug51525 = true; + private TiDBBugs() { } diff --git a/src/sqlancer/tidb/TiDBErrors.java b/src/sqlancer/tidb/TiDBErrors.java index f37bdb685..528acd0bb 100644 --- a/src/sqlancer/tidb/TiDBErrors.java +++ b/src/sqlancer/tidb/TiDBErrors.java @@ -1,5 +1,8 @@ package sqlancer.tidb; +import java.util.ArrayList; +import java.util.List; + import sqlancer.common.query.ExpectedErrors; public final class TiDBErrors { @@ -7,12 +10,15 @@ public final class TiDBErrors { private TiDBErrors() { } - public static void addExpressionErrors(ExpectedErrors errors) { + public static List getExpressionErrors() { + ArrayList errors = new ArrayList<>(); + errors.add("DECIMAL value is out of range"); errors.add("error parsing regexp"); errors.add("BIGINT UNSIGNED value is out of range"); errors.add("Data truncation: Truncated incorrect time value"); errors.add("Data truncation: Incorrect time value"); + errors.add("Data truncation: Incorrect datetime value"); errors.add("overflows double"); errors.add("overflows bigint"); errors.add("strconv.ParseFloat: parsing"); @@ -23,14 +29,20 @@ public static void addExpressionErrors(ExpectedErrors errors) { errors.add("doesn't have a default value"); // default errors.add("is not valid for CHARACTER SET"); errors.add("DOUBLE value is out of range"); + errors.add("Result of space() was larger than max_allowed_packet"); - errors.add("Data truncation: %s value is out of range in '%s'"); + errors.add("Data truncat"); errors.add("Truncated incorrect FLOAT value"); errors.add("Bad Number"); + errors.add("strconv.Atoi: parsing"); + errors.add("expected integer"); + errors.add("Duplicate entry"); // regex errors.add("error parsing regexp"); errors.add("from regexp"); + errors.add("Empty pattern is invalid"); + errors.add("Invalid regexp pattern"); // To avoid bugs errors.add("Unknown column"); // https://github.com/pingcap/tidb/issues/35522 @@ -49,14 +61,33 @@ public static void addExpressionErrors(ExpectedErrors errors) { if (TiDBBugs.bug38295) { errors.add("assertion failed"); } + if (TiDBBugs.bug44747) { + errors.add("index out of range"); + } + + return errors; } - public static void addExpressionHavingErrors(ExpectedErrors errors) { + public static void addExpressionErrors(ExpectedErrors errors) { + errors.addAll(getExpressionErrors()); + } + + public static List getExpressionHavingErrors() { + ArrayList errors = new ArrayList<>(); + errors.add("is not in GROUP BY clause and contains nonaggregated column"); errors.add("Unknown column"); + + return errors; } - public static void addInsertErrors(ExpectedErrors errors) { + public static void addExpressionHavingErrors(ExpectedErrors errors) { + errors.addAll(getExpressionHavingErrors()); + } + + public static List getInsertErrors() { + ArrayList errors = new ArrayList<>(); + errors.add("Duplicate entry"); errors.add("cannot be null"); errors.add("doesn't have a default value"); @@ -75,6 +106,17 @@ public static void addInsertErrors(ExpectedErrors errors) { errors.add("Incorrect decimal value"); errors.add("error parsing regexp"); errors.add("is not valid for CHARACTER SET"); + errors.add("for function inet_aton"); + errors.add("'Empty pattern is invalid' from regexp"); + errors.add("Data too long for expression index"); + errors.add("Data too long for column"); + errors.add("Data Too Long"); + + return errors; + } + + public static void addInsertErrors(ExpectedErrors errors) { + errors.addAll(getInsertErrors()); } } diff --git a/src/sqlancer/tidb/TiDBExpressionGenerator.java b/src/sqlancer/tidb/TiDBExpressionGenerator.java index 20f48f04a..8eaca35f0 100644 --- a/src/sqlancer/tidb/TiDBExpressionGenerator.java +++ b/src/sqlancer/tidb/TiDBExpressionGenerator.java @@ -3,13 +3,20 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; import sqlancer.IgnoreMeException; import sqlancer.Randomly; +import sqlancer.common.gen.CERTGenerator; +import sqlancer.common.gen.TLPWhereGenerator; import sqlancer.common.gen.UntypedExpressionGenerator; +import sqlancer.common.schema.AbstractTables; import sqlancer.tidb.TiDBProvider.TiDBGlobalState; import sqlancer.tidb.TiDBSchema.TiDBColumn; +import sqlancer.tidb.TiDBSchema.TiDBCompositeDataType; import sqlancer.tidb.TiDBSchema.TiDBDataType; +import sqlancer.tidb.TiDBSchema.TiDBTable; import sqlancer.tidb.ast.TiDBAggregate; import sqlancer.tidb.ast.TiDBAggregate.TiDBAggregateFunction; import sqlancer.tidb.ast.TiDBBinaryBitOperation; @@ -25,21 +32,21 @@ import sqlancer.tidb.ast.TiDBExpression; import sqlancer.tidb.ast.TiDBFunctionCall; import sqlancer.tidb.ast.TiDBFunctionCall.TiDBFunction; +import sqlancer.tidb.ast.TiDBJoin; +import sqlancer.tidb.ast.TiDBJoin.JoinType; import sqlancer.tidb.ast.TiDBOrderingTerm; import sqlancer.tidb.ast.TiDBRegexOperation; import sqlancer.tidb.ast.TiDBRegexOperation.TiDBRegexOperator; +import sqlancer.tidb.ast.TiDBSelect; +import sqlancer.tidb.ast.TiDBTableReference; import sqlancer.tidb.ast.TiDBUnaryPostfixOperation; import sqlancer.tidb.ast.TiDBUnaryPostfixOperation.TiDBUnaryPostfixOperator; import sqlancer.tidb.ast.TiDBUnaryPrefixOperation; import sqlancer.tidb.ast.TiDBUnaryPrefixOperation.TiDBUnaryPrefixOperator; -public class TiDBExpressionGenerator extends UntypedExpressionGenerator { - - private final TiDBGlobalState globalState; - - public TiDBExpressionGenerator(TiDBGlobalState globalState) { - this.globalState = globalState; - } +public class TiDBExpressionGenerator extends UntypedExpressionGenerator + implements TLPWhereGenerator, + CERTGenerator { private enum Gen { UNARY_PREFIX, // @@ -50,79 +57,12 @@ private enum Gen { // BINARY_ARITHMETIC } - @Override - protected TiDBExpression generateExpression(int depth) { - if (depth >= globalState.getOptions().getMaxExpressionDepth() || Randomly.getBoolean()) { - return generateLeafNode(); - } - if (allowAggregates && Randomly.getBoolean()) { - allowAggregates = false; - TiDBAggregateFunction func = TiDBAggregateFunction.getRandom(); - List args = generateExpressions(func.getNrArgs()); - return new TiDBAggregate(args, func); - } - switch (Randomly.fromOptions(Gen.values())) { - case DEFAULT: - if (TiDBBugs.bug15) { - throw new IgnoreMeException(); - } - if (globalState.getSchema().getDatabaseTables().isEmpty()) { - throw new IgnoreMeException(); - } - return new TiDBFunctionCall(TiDBFunction.DEFAULT, Arrays.asList(generateColumn())); - case UNARY_POSTFIX: - return new TiDBUnaryPostfixOperation(generateExpression(depth + 1), TiDBUnaryPostfixOperator.getRandom()); - case UNARY_PREFIX: - TiDBUnaryPrefixOperator rand = TiDBUnaryPrefixOperator.getRandom(); - return new TiDBUnaryPrefixOperation(generateExpression(depth + 1), rand); - case COLUMN: - return generateColumn(); - case CONSTANT: - return generateConstant(); - case COMPARISON: - return new TiDBBinaryComparisonOperation(generateExpression(depth + 1), generateExpression(depth + 1), - TiDBComparisonOperator.getRandom()); - case REGEX: - return new TiDBRegexOperation(generateExpression(depth + 1), generateExpression(depth + 1), - TiDBRegexOperator.getRandom()); - // case COLLATE: - // return new TiDBCollate(generateExpression(depth + 1), - // Randomly.fromOptions("utf8mb4_bin", "latin1_bin", "binary", "ascii_bin", "utf8_bin")); - case FUNCTION: - TiDBFunction func = TiDBFunction.getRandom(); - return new TiDBFunctionCall(func, generateExpressions(func.getNrArgs(), depth)); - case BINARY_BIT: - return new TiDBBinaryBitOperation(generateExpression(depth + 1), generateExpression(depth + 1), - TiDBBinaryBitOperator.getRandom()); - case BINARY_LOGICAL: - if (TiDBBugs.bug48) { - throw new IgnoreMeException(); - } - return new TiDBBinaryLogicalOperation(generateExpression(depth + 1), generateExpression(depth + 1), - TiDBBinaryLogicalOperator.getRandom()); - // case BINARY_ARITHMETIC: - // return new TiDBBinaryArithmeticOperation(generateExpression(depth + 1), generateExpression(depth + 1), - // TiDBBinaryArithmeticOperator.getRandom()); - case CAST: - return new TiDBCastOperation(generateExpression(depth + 1), Randomly.fromOptions("BINARY", // https://github.com/tidb-challenge-program/bug-hunting-issue/issues/52 - "CHAR", "DATE", "DATETIME", "TIME", // https://github.com/tidb-challenge-program/bug-hunting-issue/issues/13 - "DECIMAL", "SIGNED", "UNSIGNED" /* https://github.com/pingcap/tidb/issues/16028 */)); - case CASE: - if (TiDBBugs.bug19) { - throw new IgnoreMeException(); - } - int nr = Randomly.fromOptions(1, 2); - return new TiDBCase(generateExpression(depth + 1), generateExpressions(nr, depth + 1), - generateExpressions(nr, depth + 1), generateExpression(depth + 1)); - default: - throw new AssertionError(); - } - } + private final TiDBGlobalState globalState; - @Override - protected TiDBExpression generateColumn() { - TiDBColumn column = Randomly.fromList(columns); - return new TiDBColumnReference(column); + private List tables; + + public TiDBExpressionGenerator(TiDBGlobalState globalState) { + this.globalState = globalState; } @Override @@ -199,4 +139,241 @@ public TiDBExpression generateConstant(TiDBDataType type) { } } + @Override + public TiDBExpressionGenerator setTablesAndColumns(AbstractTables tables) { + this.columns = tables.getColumns(); + this.tables = tables.getTables(); + + return this; + } + + @Override + public TiDBExpression generateBooleanExpression() { + return generateExpression(); + } + + @Override + public TiDBSelect generateSelect() { + return new TiDBSelect(); + } + + @Override + public List getRandomJoinClauses() { + List tableList = tables.stream().map(t -> new TiDBTableReference(t)) + .collect(Collectors.toList()); + List joins = TiDBJoin.getJoins(tableList, globalState); + tables = tableList.stream().map(t -> ((TiDBTableReference) t).getTable()).collect(Collectors.toList()); + return joins; + } + + @Override + public List getTableRefs() { + return tables.stream().map(t -> new TiDBTableReference(t)).collect(Collectors.toList()); + } + + @Override + public List generateFetchColumns(boolean shouldCreateDummy) { + if (shouldCreateDummy && Randomly.getBoolean()) { + return List.of(new TiDBColumnReference( + new TiDBColumn("*", new TiDBCompositeDataType(TiDBDataType.INT), false, false, false))); + } + return Randomly.nonEmptySubset(this.columns).stream().map(c -> new TiDBColumnReference(c)) + .collect(Collectors.toList()); + } + + @Override + protected TiDBExpression generateExpression(int depth) { + if (depth >= globalState.getOptions().getMaxExpressionDepth() || Randomly.getBoolean()) { + return generateLeafNode(); + } + if (allowAggregates && Randomly.getBoolean()) { + allowAggregates = false; + TiDBAggregateFunction func = TiDBAggregateFunction.getRandom(); + List args = generateExpressions(func.getNrArgs()); + return new TiDBAggregate(args, func); + } + switch (Randomly.fromOptions(Gen.values())) { + case DEFAULT: + if (globalState.getSchema().getDatabaseTables().isEmpty()) { + throw new IgnoreMeException(); + } + TiDBColumn column = Randomly.fromList(columns); + if (column.hasDefault()) { + return new TiDBFunctionCall(TiDBFunction.DEFAULT, Arrays.asList(new TiDBColumnReference(column))); + } + throw new IgnoreMeException(); + case UNARY_POSTFIX: + return new TiDBUnaryPostfixOperation(generateExpression(depth + 1), TiDBUnaryPostfixOperator.getRandom()); + case UNARY_PREFIX: + TiDBUnaryPrefixOperator rand = TiDBUnaryPrefixOperator.getRandom(); + return new TiDBUnaryPrefixOperation(generateExpression(depth + 1), rand); + case COLUMN: + return generateColumn(); + case CONSTANT: + return generateConstant(); + case COMPARISON: + return new TiDBBinaryComparisonOperation(generateExpression(depth + 1), generateExpression(depth + 1), + TiDBComparisonOperator.getRandom()); + case REGEX: + return new TiDBRegexOperation(generateExpression(depth + 1), generateExpression(depth + 1), + TiDBRegexOperator.getRandom()); + case FUNCTION: + TiDBFunction func = TiDBFunction.getRandom(); + return new TiDBFunctionCall(func, generateExpressions(func.getNrArgs(), depth)); + case BINARY_BIT: + return new TiDBBinaryBitOperation(generateExpression(depth + 1), generateExpression(depth + 1), + TiDBBinaryBitOperator.getRandom()); + case BINARY_LOGICAL: + return new TiDBBinaryLogicalOperation(generateExpression(depth + 1), generateExpression(depth + 1), + TiDBBinaryLogicalOperator.getRandom()); + case CAST: + return new TiDBCastOperation(generateExpression(depth + 1), Randomly.fromOptions("BINARY", // https://github.com/tidb-challenge-program/bug-hunting-issue/issues/52 + "CHAR", "DATE", "DATETIME", "TIME", // https://github.com/tidb-challenge-program/bug-hunting-issue/issues/13 + "DECIMAL", "SIGNED", "UNSIGNED" /* https://github.com/pingcap/tidb/issues/16028 */)); + case CASE: + int nr = Randomly.fromOptions(1, 2); + return new TiDBCase(generateExpression(depth + 1), generateExpressions(nr, depth + 1), + generateExpressions(nr, depth + 1), generateExpression(depth + 1)); + default: + throw new AssertionError(); + } + } + + @Override + protected TiDBExpression generateColumn() { + TiDBColumn column = Randomly.fromList(columns); + return new TiDBColumnReference(column); + } + + @Override + public String generateExplainQuery(TiDBSelect select) { + return "EXPLAIN " + select.asString(); + } + + @Override + public boolean mutate(TiDBSelect select) { + List> mutators = new ArrayList<>(); + + mutators.add(this::mutateJoin); + mutators.add(this::mutateWhere); + if (!TiDBBugs.bug38319) { + mutators.add(this::mutateGroupBy); + mutators.add(this::mutateHaving); + } + mutators.add(this::mutateAnd); + if (!TiDBBugs.bug51525) { + mutators.add(this::mutateOr); + } + mutators.add(this::mutateLimit); + // mutators.add(this::mutateDistinct); + + return Randomly.fromList(mutators).apply(select); + } + + boolean mutateJoin(TiDBSelect select) { + if (select.getJoinList().isEmpty()) { + return false; + } + TiDBJoin join = (TiDBJoin) Randomly.fromList(select.getJoinList()); + if (join.getJoinType() == JoinType.NATURAL) { + return false; + } + + // CROSS does not need ON Condition, while other joins do + // To avoid Null pointer, generating a new new condition when mutating CROSS to + // other joins + if (join.getJoinType() == JoinType.CROSS) { + List columns = new ArrayList<>(); + columns.addAll(((TiDBTableReference) join.getLeftTable()).getTable().getColumns()); + columns.addAll(((TiDBTableReference) join.getRightTable()).getTable().getColumns()); + TiDBExpressionGenerator joinGen2 = new TiDBExpressionGenerator(globalState).setColumns(columns); + join.setOnCondition(joinGen2.generateExpression()); + } + + JoinType newJoinType = TiDBJoin.JoinType.INNER; + if (join.getJoinType() == JoinType.LEFT || join.getJoinType() == JoinType.RIGHT) { // No invarient relation + // between LEFT and RIGHT + // join + newJoinType = JoinType.getRandomExcept(JoinType.NATURAL, JoinType.LEFT, JoinType.RIGHT); + } else { + newJoinType = JoinType.getRandomExcept(JoinType.NATURAL, join.getJoinType()); + } + assert newJoinType != JoinType.NATURAL; // Natural Join is not supported for CERT + boolean increase = join.getJoinType().ordinal() < newJoinType.ordinal(); + join.setJoinType(newJoinType); + if (newJoinType == JoinType.CROSS) { + join.setOnCondition(null); + } + return increase; + } + + boolean mutateWhere(TiDBSelect select) { + boolean increase = select.getWhereClause() != null; + if (increase) { + select.setWhereClause(null); + } else { + select.setWhereClause(generateExpression()); + } + return increase; + } + + boolean mutateHaving(TiDBSelect select) { + if (select.getGroupByExpressions().isEmpty()) { + select.setGroupByExpressions(select.getFetchColumns()); + select.setHavingClause(generateExpression()); + return false; + } else { + if (select.getHavingClause() == null) { + select.setHavingClause(generateExpression()); + return false; + } else { + select.setHavingClause(null); + return true; + } + } + } + + boolean mutateAnd(TiDBSelect select) { + if (select.getWhereClause() == null) { + select.setWhereClause(generateExpression()); + } else { + TiDBExpression newWhere = new TiDBBinaryLogicalOperation(select.getWhereClause(), generateExpression(), + TiDBBinaryLogicalOperator.AND); + select.setWhereClause(newWhere); + } + return false; + } + + boolean mutateOr(TiDBSelect select) { + if (select.getWhereClause() == null) { + select.setWhereClause(generateExpression()); + return false; + } else { + TiDBExpression newWhere = new TiDBBinaryLogicalOperation(select.getWhereClause(), generateExpression(), + TiDBBinaryLogicalOperator.OR); + select.setWhereClause(newWhere); + return true; + } + } + + boolean mutateLimit(TiDBSelect select) { + boolean increase = select.getLimitClause() != null; + if (increase) { + select.setLimitClause(null); + } else { + select.setLimitClause(generateConstant(TiDBDataType.INT)); + } + return increase; + } + + private boolean mutateGroupBy(TiDBSelect select) { + boolean increase = !select.getGroupByExpressions().isEmpty(); + if (increase) { + select.clearGroupByExpressions(); + select.clearHavingClause(); + } else { + select.setGroupByExpressions(select.getFetchColumns()); + } + return increase; + } } diff --git a/src/sqlancer/tidb/TiDBOptions.java b/src/sqlancer/tidb/TiDBOptions.java index f9137f6f7..1619832d6 100644 --- a/src/sqlancer/tidb/TiDBOptions.java +++ b/src/sqlancer/tidb/TiDBOptions.java @@ -1,7 +1,5 @@ package sqlancer.tidb; -import java.sql.SQLException; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -9,13 +7,6 @@ import com.beust.jcommander.Parameters; import sqlancer.DBMSSpecificOptions; -import sqlancer.OracleFactory; -import sqlancer.common.oracle.CompositeTestOracle; -import sqlancer.common.oracle.TestOracle; -import sqlancer.tidb.TiDBOptions.TiDBOracleFactory; -import sqlancer.tidb.TiDBProvider.TiDBGlobalState; -import sqlancer.tidb.oracle.TiDBTLPHavingOracle; -import sqlancer.tidb.oracle.TiDBTLPWhereOracle; @Parameters(separators = "=", commandDescription = "TiDB (default port: " + TiDBOptions.DEFAULT_PORT + ", default host: " + TiDBOptions.DEFAULT_HOST + ")") @@ -23,37 +14,23 @@ public class TiDBOptions implements DBMSSpecificOptions { public static final String DEFAULT_HOST = "localhost"; public static final int DEFAULT_PORT = 4000; + @Parameter(names = { "--max-num-tables" }, description = "The maximum number of tables/views that can be created") + public int maxNumTables = 10; + + @Parameter(names = { "--max-num-indexes" }, description = "The maximum number of indexes that can be created") + public int maxNumIndexes = 20; + @Parameter(names = "--oracle") public List oracle = Arrays.asList(TiDBOracleFactory.QUERY_PARTITIONING); - public enum TiDBOracleFactory implements OracleFactory { - HAVING { - @Override - public TestOracle create(TiDBGlobalState globalState) throws SQLException { - return new TiDBTLPHavingOracle(globalState); - } - }, - WHERE { - @Override - public TestOracle create(TiDBGlobalState globalState) throws SQLException { - return new TiDBTLPWhereOracle(globalState); - } - }, - QUERY_PARTITIONING { - @Override - public TestOracle create(TiDBGlobalState globalState) throws SQLException { - List oracles = new ArrayList<>(); - oracles.add(new TiDBTLPWhereOracle(globalState)); - oracles.add(new TiDBTLPHavingOracle(globalState)); - return new CompositeTestOracle(oracles, globalState); - } - }; + @Parameter(names = "--enable-non-prepared-plan-cache") + public boolean nonPreparePlanCache; - } + @Parameter(names = { "--tiflash" }, description = "Enable TiFlash") + public boolean tiflash; @Override public List getTestOracleFactory() { return oracle; } - } diff --git a/src/sqlancer/tidb/TiDBOracleFactory.java b/src/sqlancer/tidb/TiDBOracleFactory.java new file mode 100644 index 000000000..173ff7abd --- /dev/null +++ b/src/sqlancer/tidb/TiDBOracleFactory.java @@ -0,0 +1,77 @@ +package sqlancer.tidb; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import sqlancer.OracleFactory; +import sqlancer.common.oracle.CERTOracle; +import sqlancer.common.oracle.CompositeTestOracle; +import sqlancer.common.oracle.TLPWhereOracle; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLancerResultSet; +import sqlancer.tidb.oracle.TiDBDQPOracle; +import sqlancer.tidb.oracle.TiDBTLPHavingOracle; + +public enum TiDBOracleFactory implements OracleFactory { + HAVING { + @Override + public TestOracle create(TiDBProvider.TiDBGlobalState globalState) + throws SQLException { + return new TiDBTLPHavingOracle(globalState); + } + }, + WHERE { + @Override + public TestOracle create(TiDBProvider.TiDBGlobalState globalState) + throws SQLException { + TiDBExpressionGenerator gen = new TiDBExpressionGenerator(globalState); + ExpectedErrors expectedErrors = ExpectedErrors.newErrors().with(TiDBErrors.getExpressionErrors()).build(); + + return new TLPWhereOracle<>(globalState, gen, expectedErrors); + } + }, + QUERY_PARTITIONING { + @Override + public TestOracle create(TiDBProvider.TiDBGlobalState globalState) + throws Exception { + List> oracles = new ArrayList<>(); + oracles.add(WHERE.create(globalState)); + oracles.add(HAVING.create(globalState)); + return new CompositeTestOracle(oracles, globalState); + } + }, + CERT { + @Override + public TestOracle create(TiDBProvider.TiDBGlobalState globalState) + throws SQLException { + TiDBExpressionGenerator gen = new TiDBExpressionGenerator(globalState); + ExpectedErrors expectedErrors = ExpectedErrors.newErrors().with(TiDBErrors.getExpressionErrors()).build(); + CERTOracle.CheckedFunction> rowCountParser = (rs) -> { + String content = rs.getString(2); + return Optional.of((long) Double.parseDouble(content)); + }; + CERTOracle.CheckedFunction> queryPlanParser = (rs) -> { + String operation = rs.getString(1).split("_")[0]; // Extract operation names for query plans + return Optional.of(operation); + }; + + return new CERTOracle<>(globalState, gen, expectedErrors, rowCountParser, queryPlanParser); + } + + @Override + public boolean requiresAllTablesToContainRows() { + return true; + } + }, + DQP { + @Override + public TestOracle create(TiDBProvider.TiDBGlobalState globalState) + throws SQLException { + return new TiDBDQPOracle(globalState); + } + }; + +} diff --git a/src/sqlancer/tidb/TiDBProvider.java b/src/sqlancer/tidb/TiDBProvider.java index af41fb143..1d117e234 100644 --- a/src/sqlancer/tidb/TiDBProvider.java +++ b/src/sqlancer/tidb/TiDBProvider.java @@ -1,9 +1,12 @@ package sqlancer.tidb; +import java.io.IOException; import java.sql.Connection; import java.sql.DriverManager; import java.sql.SQLException; import java.sql.Statement; +import java.util.List; +import java.util.stream.Collectors; import com.google.auto.service.AutoService; @@ -19,13 +22,16 @@ import sqlancer.common.query.ExpectedErrors; import sqlancer.common.query.SQLQueryAdapter; import sqlancer.common.query.SQLQueryProvider; +import sqlancer.common.query.SQLancerResultSet; import sqlancer.tidb.TiDBProvider.TiDBGlobalState; +import sqlancer.tidb.TiDBSchema.TiDBTable; import sqlancer.tidb.gen.TiDBAlterTableGenerator; import sqlancer.tidb.gen.TiDBAnalyzeTableGenerator; import sqlancer.tidb.gen.TiDBDeleteGenerator; +import sqlancer.tidb.gen.TiDBDropTableGenerator; +import sqlancer.tidb.gen.TiDBDropViewGenerator; import sqlancer.tidb.gen.TiDBIndexGenerator; import sqlancer.tidb.gen.TiDBInsertGenerator; -import sqlancer.tidb.gen.TiDBRandomQuerySynthesizer; import sqlancer.tidb.gen.TiDBSetGenerator; import sqlancer.tidb.gen.TiDBTableGenerator; import sqlancer.tidb.gen.TiDBUpdateGenerator; @@ -39,25 +45,20 @@ public TiDBProvider() { } public enum Action implements AbstractAction { - INSERT(TiDBInsertGenerator::getQuery), // - ANALYZE_TABLE(TiDBAnalyzeTableGenerator::getQuery), // - TRUNCATE((g) -> new SQLQueryAdapter("TRUNCATE " + g.getSchema().getRandomTable(t -> !t.isView()).getName())), // - CREATE_INDEX(TiDBIndexGenerator::getQuery), // - DELETE(TiDBDeleteGenerator::getQuery), // - SET(TiDBSetGenerator::getQuery), // - UPDATE(TiDBUpdateGenerator::getQuery), // + CREATE_TABLE(TiDBTableGenerator::createRandomTableStatement), // 0 + CREATE_INDEX(TiDBIndexGenerator::getQuery), // 1 + VIEW_GENERATOR(TiDBViewGenerator::getQuery), // 2 + INSERT(TiDBInsertGenerator::getQuery), // 3 + ALTER_TABLE(TiDBAlterTableGenerator::getQuery), // 4 + TRUNCATE((g) -> new SQLQueryAdapter("TRUNCATE " + g.getSchema().getRandomTable(t -> !t.isView()).getName())), // 5 + UPDATE(TiDBUpdateGenerator::getQuery), // 6 + DELETE(TiDBDeleteGenerator::getQuery), // 7 + SET(TiDBSetGenerator::getQuery), // 8 ADMIN_CHECKSUM_TABLE( - (g) -> new SQLQueryAdapter("ADMIN CHECKSUM TABLE " + g.getSchema().getRandomTable().getName())), // - VIEW_GENERATOR(TiDBViewGenerator::getQuery), // - ALTER_TABLE(TiDBAlterTableGenerator::getQuery), // - EXPLAIN((g) -> { - ExpectedErrors errors = new ExpectedErrors(); - TiDBErrors.addExpressionErrors(errors); - TiDBErrors.addExpressionHavingErrors(errors); - return new SQLQueryAdapter( - "EXPLAIN " + TiDBRandomQuerySynthesizer.generate(g, Randomly.smallNumber() + 1).getQueryString(), - errors); - }); + (g) -> new SQLQueryAdapter("ADMIN CHECKSUM TABLE " + g.getSchema().getRandomTable().getName())), // 9 + ANALYZE_TABLE(TiDBAnalyzeTableGenerator::getQuery), // 10 + DROP_TABLE(TiDBDropTableGenerator::dropTable), // 11 + DROP_VIEW(TiDBDropViewGenerator::dropView); // 12 private final SQLQueryProvider sqlQueryProvider; @@ -87,7 +88,6 @@ private static int mapActions(TiDBGlobalState globalState, Action a) { case CREATE_INDEX: return r.getInteger(0, 2); case INSERT: - case EXPLAIN: return r.getInteger(0, globalState.getOptions().getMaxNumberInserts()); case TRUNCATE: case DELETE: @@ -101,6 +101,10 @@ private static int mapActions(TiDBGlobalState globalState, Action a) { return r.getInteger(0, 2); case ALTER_TABLE: return r.getInteger(0, 10); // https://github.com/tidb-challenge-program/bug-hunting-issue/issues/10 + case CREATE_TABLE: + case DROP_TABLE: + case DROP_VIEW: + return 0; default: throw new AssertionError(a); } @@ -133,6 +137,37 @@ public void generateDatabase(TiDBGlobalState globalState) throws Exception { throw new AssertionError(e); } } + + if (globalState.getDbmsSpecificOptions().getTestOracleFactory().stream() + .anyMatch((o) -> o == TiDBOracleFactory.CERT)) { + // Disable strict Group By constraints for ROW oracle + globalState.executeStatement(new SQLQueryAdapter( + "SET @@sql_mode='STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_AUTO_CREATE_USER,NO_ENGINE_SUBSTITUTION';")); + + // Enfore statistic collected for all tables + ExpectedErrors errors = new ExpectedErrors(); + TiDBErrors.addExpressionErrors(errors); + for (TiDBTable table : globalState.getSchema().getDatabaseTables()) { + if (!table.isView()) { + globalState.executeStatement(new SQLQueryAdapter("ANALYZE TABLE " + table.getName() + ";", errors)); + } + } + } + + // TiFlash replication settings + if (globalState.getDbmsSpecificOptions().tiflash) { + ExpectedErrors errors = new ExpectedErrors(); + TiDBErrors.addExpressionErrors(errors); + for (TiDBTable table : globalState.getSchema().getDatabaseTables()) { + if (!table.isView()) { + globalState.executeStatement( + new SQLQueryAdapter("ALTER TABLE " + table.getName() + " SET TIFLASH REPLICA 1;", errors)); + } + } + if (Randomly.getBoolean()) { + globalState.executeStatement(new SQLQueryAdapter("set @@tidb_enforce_mpp=1;")); + } + } } @Override @@ -157,6 +192,9 @@ public SQLConnection createDatabase(TiDBGlobalState globalState) throws SQLExcep globalState.getState().logStatement("USE " + databaseName); try (Statement s = con.createStatement()) { s.execute("DROP DATABASE IF EXISTS " + databaseName); + if (globalState.getDbmsSpecificOptions().nonPreparePlanCache) { + s.execute("set global tidb_enable_non_prepared_plan_cache=ON;"); + } } try (Statement s = con.createStatement()) { s.execute(createDatabaseCommand); @@ -172,4 +210,54 @@ public String getDBMSName() { return "tidb"; } + @Override + public String getQueryPlan(String selectStr, TiDBGlobalState globalState) throws Exception { + String queryPlan = ""; + if (globalState.getOptions().logEachSelect()) { + globalState.getLogger().writeCurrent(selectStr); + try { + globalState.getLogger().getCurrentFileWriter().flush(); + } catch (IOException e) { + e.printStackTrace(); + } + } + + SQLQueryAdapter q = new SQLQueryAdapter("EXPLAIN FORMAT=brief " + selectStr); + try (SQLancerResultSet rs = q.executeAndGet(globalState)) { + if (rs != null) { + while (rs.next()) { + String targetQueryPlan = rs.getString(1).replace("├─", "").replace("└─", "").replace("│", "").trim() + + ";"; // Unify format + queryPlan += targetQueryPlan; + } + } + } catch (Throwable e) { + e.printStackTrace(); + } + + return queryPlan; + } + + @Override + protected double[] initializeWeightedAverageReward() { + return new double[Action.values().length]; + } + + @Override + protected void executeMutator(int index, TiDBGlobalState globalState) throws Exception { + SQLQueryAdapter queryMutateTable = Action.values()[index].getQuery(globalState); + globalState.executeStatement(queryMutateTable); + } + + @Override + public boolean addRowsToAllTables(TiDBGlobalState globalState) throws Exception { + List tablesNoRow = globalState.getSchema().getDatabaseTables().stream() + .filter(t -> t.getNrRows(globalState) == 0).collect(Collectors.toList()); + for (TiDBTable table : tablesNoRow) { + SQLQueryAdapter queryAddRows = TiDBInsertGenerator.getQuery(globalState, table); + globalState.executeStatement(queryAddRows); + } + return true; + } + } diff --git a/src/sqlancer/tidb/TiDBSchema.java b/src/sqlancer/tidb/TiDBSchema.java index 7caff2974..4ce7306e5 100644 --- a/src/sqlancer/tidb/TiDBSchema.java +++ b/src/sqlancer/tidb/TiDBSchema.java @@ -160,13 +160,17 @@ public static class TiDBColumn extends AbstractTableColumn { @@ -227,37 +235,51 @@ private static TiDBCompositeDataType getColumnType(String typeString) { primitiveType = TiDBDataType.TEXT; break; case "float": + size = 4; + primitiveType = TiDBDataType.FLOATING; + break; case "double": + case "double(8,6)": // workaround to address https://github.com/sqlancer/sqlancer/issues/669 + case "double(23,16)": + size = 8; primitiveType = TiDBDataType.FLOATING; break; case "tinyint(1)": primitiveType = TiDBDataType.BOOL; + size = 1; break; case "null": primitiveType = TiDBDataType.INT; + size = 1; break; + case "tinyint": + case "tinyint(2)": case "tinyint(3)": case "tinyint(4)": primitiveType = TiDBDataType.INT; size = 1; break; + case "smallint": case "smallint(5)": case "smallint(6)": primitiveType = TiDBDataType.INT; size = 2; break; + case "int": case "int(10)": case "int(11)": primitiveType = TiDBDataType.INT; size = 4; break; case "blob": + case "mediumblob": case "longblob": case "tinyblob": primitiveType = TiDBDataType.BLOB; break; case "date": case "datetime": + case "datetime(6)": // workaround to address https://github.com/sqlancer/sqlancer/issues/669 case "timestamp": case "time": case "year": @@ -276,10 +298,6 @@ public TiDBTable(String tableName, List columns, List in super(tableName, columns, indexes, isView); } - public boolean hasPrimaryKey() { - return getColumns().stream().anyMatch(c -> c.isPrimaryKey()); - } - } public static TiDBSchema fromConnection(SQLConnection con, String databaseName) throws SQLException { @@ -292,7 +310,7 @@ public static TiDBSchema fromConnection(SQLConnection con, String databaseName) continue; } List indexes = getIndexes(con, tableName); - boolean isView = tableName.startsWith("v"); + boolean isView = matchesViewName(tableName); TiDBTable t = new TiDBTable(tableName, databaseColumns, indexes, isView); for (TiDBColumn c : databaseColumns) { c.setTable(t); @@ -337,7 +355,9 @@ private static List getTableColumns(SQLConnection con, String tableN String dataType = rs.getString("Type"); boolean isNullable = rs.getString("Null").contentEquals("YES"); boolean isPrimaryKey = rs.getString("Key").contains("PRI"); - TiDBColumn c = new TiDBColumn(columnName, getColumnType(dataType), isPrimaryKey, isNullable); + boolean hasDefault = rs.getString("Default") != null; + TiDBColumn c = new TiDBColumn(columnName, getColumnType(dataType), isPrimaryKey, isNullable, + hasDefault); columns.add(c); } } diff --git a/src/sqlancer/tidb/ast/TiDBAggregate.java b/src/sqlancer/tidb/ast/TiDBAggregate.java index 21bd8eb0c..6650dc395 100644 --- a/src/sqlancer/tidb/ast/TiDBAggregate.java +++ b/src/sqlancer/tidb/ast/TiDBAggregate.java @@ -9,11 +9,7 @@ public class TiDBAggregate extends FunctionNode implements TiDBExpression { public enum TiDBAggregateFunction { - COUNT(1), // - SUM(1), // - AVG(1), // - MIN(1), // - MAX(1); + AVG(1), BIT_AND(1), BIT_OR(1), COUNT(1), SUM(1), MIN(1), MAX(1); private int nrArgs; diff --git a/src/sqlancer/tidb/ast/TiDBExpression.java b/src/sqlancer/tidb/ast/TiDBExpression.java index 49ac9fb70..1f4921836 100644 --- a/src/sqlancer/tidb/ast/TiDBExpression.java +++ b/src/sqlancer/tidb/ast/TiDBExpression.java @@ -1,5 +1,8 @@ package sqlancer.tidb.ast; -public interface TiDBExpression { +import sqlancer.common.ast.newast.Expression; +import sqlancer.tidb.TiDBSchema.TiDBColumn; + +public interface TiDBExpression extends Expression { } diff --git a/src/sqlancer/tidb/ast/TiDBJoin.java b/src/sqlancer/tidb/ast/TiDBJoin.java index e22b34b6a..1e9a30b63 100644 --- a/src/sqlancer/tidb/ast/TiDBJoin.java +++ b/src/sqlancer/tidb/ast/TiDBJoin.java @@ -1,27 +1,36 @@ package sqlancer.tidb.ast; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import sqlancer.Randomly; +import sqlancer.common.ast.newast.Join; import sqlancer.tidb.TiDBExpressionGenerator; import sqlancer.tidb.TiDBProvider.TiDBGlobalState; import sqlancer.tidb.TiDBSchema.TiDBColumn; +import sqlancer.tidb.TiDBSchema.TiDBTable; -public class TiDBJoin implements TiDBExpression { +public class TiDBJoin implements TiDBExpression, Join { private final TiDBExpression leftTable; private final TiDBExpression rightTable; - private final JoinType joinType; - private final TiDBExpression onCondition; + private JoinType joinType; + private TiDBExpression onCondition; private NaturalJoinType outerType; public enum JoinType { - INNER, NATURAL, STRAIGHT, LEFT, RIGHT; + NATURAL, INNER, STRAIGHT, LEFT, RIGHT, CROSS; public static JoinType getRandom() { return Randomly.fromOptions(values()); } + + public static JoinType getRandomExcept(JoinType... exclude) { + JoinType[] values = Arrays.stream(values()).filter(m -> !Arrays.asList(exclude).contains(m)) + .toArray(JoinType[]::new); + return Randomly.fromOptions(values); + } } public enum NaturalJoinType { @@ -52,10 +61,18 @@ public JoinType getJoinType() { return joinType; } + public void setJoinType(JoinType joinType) { + this.joinType = joinType; + } + public TiDBExpression getOnCondition() { return onCondition; } + public static TiDBJoin createCrossJoin(TiDBExpression left, TiDBExpression right, TiDBExpression onClause) { + return new TiDBJoin(left, right, JoinType.CROSS, onClause); + } + public static TiDBJoin createNaturalJoin(TiDBExpression left, TiDBExpression right, NaturalJoinType type) { TiDBJoin tiDBJoin = new TiDBJoin(left, right, JoinType.NATURAL, null); tiDBJoin.setNaturalJoinType(type); @@ -86,8 +103,8 @@ public NaturalJoinType getNaturalJoinType() { return outerType; } - public static List getJoins(List tableList, TiDBGlobalState globalState) { - List joinExpressions = new ArrayList<>(); + public static List getJoins(List tableList, TiDBGlobalState globalState) { + List joinExpressions = new ArrayList<>(); while (tableList.size() >= 2 && Randomly.getBoolean()) { TiDBTableReference leftTable = (TiDBTableReference) tableList.remove(0); TiDBTableReference rightTable = (TiDBTableReference) tableList.remove(0); @@ -110,6 +127,9 @@ public static List getJoins(List tableList, TiDB case RIGHT: joinExpressions.add(TiDBJoin.createRightOuterJoin(leftTable, rightTable, joinGen.generateExpression())); break; + case CROSS: + joinExpressions.add(TiDBJoin.createCrossJoin(leftTable, rightTable, null)); + break; default: throw new AssertionError(); } @@ -117,4 +137,45 @@ public static List getJoins(List tableList, TiDB return joinExpressions; } + public static List getJoinsWithoutNature(List tableList, + TiDBGlobalState globalState) { + List joinExpressions = new ArrayList<>(); + while (tableList.size() >= 2 && Randomly.getBoolean()) { + TiDBTableReference leftTable = (TiDBTableReference) tableList.remove(0); + TiDBTableReference rightTable = (TiDBTableReference) tableList.remove(0); + List columns = new ArrayList<>(leftTable.getTable().getColumns()); + columns.addAll(rightTable.getTable().getColumns()); + TiDBExpressionGenerator joinGen = new TiDBExpressionGenerator(globalState).setColumns(columns); + switch (TiDBJoin.JoinType.getRandom()) { + case INNER: + joinExpressions.add(TiDBJoin.createInnerJoin(leftTable, rightTable, joinGen.generateExpression())); + break; + case STRAIGHT: + joinExpressions.add(TiDBJoin.createStraightJoin(leftTable, rightTable, joinGen.generateExpression())); + break; + case LEFT: + joinExpressions.add(TiDBJoin.createLeftOuterJoin(leftTable, rightTable, joinGen.generateExpression())); + break; + case RIGHT: + joinExpressions.add(TiDBJoin.createRightOuterJoin(leftTable, rightTable, joinGen.generateExpression())); + break; + case NATURAL: + case CROSS: + joinExpressions.add(TiDBJoin.createCrossJoin(leftTable, rightTable, null)); + break; + default: + throw new AssertionError(); + } + } + return joinExpressions; + } + + public void setOnCondition(TiDBExpression generateExpression) { + this.onCondition = generateExpression; + } + + @Override + public void setOnClause(TiDBExpression onClause) { + onCondition = onClause; + } } diff --git a/src/sqlancer/tidb/ast/TiDBSelect.java b/src/sqlancer/tidb/ast/TiDBSelect.java index f812ee8e0..7da05da4c 100644 --- a/src/sqlancer/tidb/ast/TiDBSelect.java +++ b/src/sqlancer/tidb/ast/TiDBSelect.java @@ -1,8 +1,16 @@ package sqlancer.tidb.ast; +import java.util.List; +import java.util.stream.Collectors; + import sqlancer.common.ast.SelectBase; +import sqlancer.common.ast.newast.Select; +import sqlancer.tidb.TiDBSchema.TiDBColumn; +import sqlancer.tidb.TiDBSchema.TiDBTable; +import sqlancer.tidb.visitor.TiDBVisitor; -public class TiDBSelect extends SelectBase implements TiDBExpression { +public class TiDBSelect extends SelectBase + implements TiDBExpression, Select { private TiDBExpression hint; @@ -14,4 +22,20 @@ public TiDBExpression getHint() { return hint; } + @Override + public void setJoinClauses(List joinStatements) { + List expressions = joinStatements.stream().map(e -> (TiDBExpression) e) + .collect(Collectors.toList()); + setJoinList(expressions); + } + + @Override + public List getJoinClauses() { + return getJoinList().stream().map(e -> (TiDBJoin) e).collect(Collectors.toList()); + } + + @Override + public String asString() { + return TiDBVisitor.asString(this); + } } diff --git a/src/sqlancer/tidb/gen/TiDBAlterTableGenerator.java b/src/sqlancer/tidb/gen/TiDBAlterTableGenerator.java index 057bb2789..4d0807d4b 100644 --- a/src/sqlancer/tidb/gen/TiDBAlterTableGenerator.java +++ b/src/sqlancer/tidb/gen/TiDBAlterTableGenerator.java @@ -6,9 +6,9 @@ import sqlancer.Randomly; import sqlancer.common.query.ExpectedErrors; import sqlancer.common.query.SQLQueryAdapter; -import sqlancer.tidb.TiDBBugs; import sqlancer.tidb.TiDBProvider.TiDBGlobalState; import sqlancer.tidb.TiDBSchema.TiDBColumn; +import sqlancer.tidb.TiDBSchema.TiDBCompositeDataType; import sqlancer.tidb.TiDBSchema.TiDBDataType; import sqlancer.tidb.TiDBSchema.TiDBTable; @@ -25,13 +25,15 @@ public static SQLQueryAdapter getQuery(TiDBGlobalState globalState) { ExpectedErrors errors = new ExpectedErrors(); errors.add( "Information schema is changed during the execution of the statement(for example, table definition may be updated by other DDL ran in parallel)"); - errors.add("Data truncated"); - errors.add("Data truncation"); + errors.add("Data truncat"); errors.add("without a key length"); - errors.add("charset"); - errors.add("not supported"); + errors.add("supported"); errors.add("SQL syntax"); errors.add("can't drop"); + errors.add("A PRIMARY must include all columns in the table's partitioning function"); + errors.add("key was too long"); + errors.add("Duplicate entry"); + errors.add("has a partitioning function dependency and cannot be dropped or renamed"); StringBuilder sb = new StringBuilder("ALTER TABLE "); TiDBTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); TiDBColumn column = table.getRandomColumn(); @@ -40,14 +42,10 @@ public static SQLQueryAdapter getQuery(TiDBGlobalState globalState) { sb.append(" "); switch (a) { case MODIFY_COLUMN: - if (TiDBBugs.bug10) { - throw new IgnoreMeException(); - } sb.append("MODIFY "); sb.append(column.getName()); sb.append(" "); - sb.append(TiDBDataType.getRandom()); - errors.add("Unsupported modify column"); + sb.append(TiDBCompositeDataType.getRandom().toString()); break; case DROP_COLUMN: sb.append(" DROP "); @@ -94,15 +92,12 @@ public static SQLQueryAdapter getQuery(TiDBGlobalState globalState) { errors.add("'Defining a virtual generated column as primary key' is not supported for generated columns"); break; case CHANGE: - if (TiDBBugs.bug10) { - throw new IgnoreMeException(); - } sb.append(" CHANGE "); sb.append(column.getName()); sb.append(" "); sb.append(column.getName()); sb.append(" "); - sb.append(column.getType().getPrimitiveDataType()); + sb.append(column.getType().toString()); sb.append(" NOT NULL "); errors.add("Invalid use of NULL value"); errors.add("Unsupported modify column:"); diff --git a/src/sqlancer/tidb/gen/TiDBAnalyzeTableGenerator.java b/src/sqlancer/tidb/gen/TiDBAnalyzeTableGenerator.java index d84b117c2..ede857dc5 100644 --- a/src/sqlancer/tidb/gen/TiDBAnalyzeTableGenerator.java +++ b/src/sqlancer/tidb/gen/TiDBAnalyzeTableGenerator.java @@ -1,9 +1,13 @@ package sqlancer.tidb.gen; import java.sql.SQLException; +import java.util.List; import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.schema.TableIndex; +import sqlancer.tidb.TiDBErrors; import sqlancer.tidb.TiDBProvider.TiDBGlobalState; import sqlancer.tidb.TiDBSchema.TiDBTable; @@ -13,24 +17,27 @@ private TiDBAnalyzeTableGenerator() { } public static SQLQueryAdapter getQuery(TiDBGlobalState globalState) throws SQLException { + ExpectedErrors errors = ExpectedErrors.newErrors().with(TiDBErrors.getExpressionErrors()).build(); TiDBTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); - boolean analyzeIndex = !table.getIndexes().isEmpty() && Randomly.getBoolean(); - StringBuilder sb = new StringBuilder("ANALYZE "); - if (analyzeIndex && Randomly.getBoolean()) { - sb.append("INCREMENTAL "); - } - sb.append("TABLE "); + List indexes = table.getIndexes(); + indexes.removeIf(index -> index.getIndexName().contains("PRIMARY")); + boolean analyzeIndex = !indexes.isEmpty() && Randomly.getBoolean(); + StringBuilder sb = new StringBuilder("ANALYZE TABLE "); sb.append(table.getName()); if (analyzeIndex) { sb.append(" INDEX "); - sb.append(table.getRandomIndex().getIndexName()); + sb.append(Randomly.fromList(indexes).getIndexName()); + } + if (!analyzeIndex && Randomly.getBoolean()) { + sb.append(" ALL COLUMNS"); } if (Randomly.getBoolean()) { sb.append(" WITH "); sb.append(Randomly.getNotCachedInteger(1, 1024)); sb.append(" BUCKETS"); } - return new SQLQueryAdapter(sb.toString()); + errors.add("Fast analyze hasn't reached General Availability and only support analyze version 1 currently"); + return new SQLQueryAdapter(sb.toString(), errors); } } diff --git a/src/sqlancer/tidb/gen/TiDBDeleteGenerator.java b/src/sqlancer/tidb/gen/TiDBDeleteGenerator.java index f6eea4671..c3986f8d9 100644 --- a/src/sqlancer/tidb/gen/TiDBDeleteGenerator.java +++ b/src/sqlancer/tidb/gen/TiDBDeleteGenerator.java @@ -18,7 +18,7 @@ private TiDBDeleteGenerator() { } public static SQLQueryAdapter getQuery(TiDBGlobalState globalState) throws SQLException { - ExpectedErrors errors = new ExpectedErrors(); + ExpectedErrors errors = ExpectedErrors.newErrors().with(TiDBErrors.getExpressionErrors()).build(); TiDBTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); TiDBExpressionGenerator gen = new TiDBExpressionGenerator(globalState).setColumns(table.getColumns()); StringBuilder sb = new StringBuilder("DELETE "); diff --git a/src/sqlancer/tidb/gen/TiDBDropTableGenerator.java b/src/sqlancer/tidb/gen/TiDBDropTableGenerator.java new file mode 100644 index 000000000..727fffffc --- /dev/null +++ b/src/sqlancer/tidb/gen/TiDBDropTableGenerator.java @@ -0,0 +1,25 @@ +package sqlancer.tidb.gen; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.tidb.TiDBProvider.TiDBGlobalState; + +public final class TiDBDropTableGenerator { + + private TiDBDropTableGenerator() { + } + + public static SQLQueryAdapter dropTable(TiDBGlobalState globalState) { + if (globalState.getSchema().getTables(t -> !t.isView()).size() <= 1) { + throw new IgnoreMeException(); + } + StringBuilder sb = new StringBuilder("DROP TABLE "); + if (Randomly.getBoolean()) { + sb.append("IF EXISTS "); + } + sb.append(globalState.getSchema().getRandomTableOrBailout(t -> !t.isView()).getName()); + return new SQLQueryAdapter(sb.toString(), null, true); + } + +} diff --git a/src/sqlancer/tidb/gen/TiDBDropViewGenerator.java b/src/sqlancer/tidb/gen/TiDBDropViewGenerator.java new file mode 100644 index 000000000..486b5f873 --- /dev/null +++ b/src/sqlancer/tidb/gen/TiDBDropViewGenerator.java @@ -0,0 +1,25 @@ +package sqlancer.tidb.gen; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.tidb.TiDBProvider.TiDBGlobalState; + +public final class TiDBDropViewGenerator { + + private TiDBDropViewGenerator() { + } + + public static SQLQueryAdapter dropView(TiDBGlobalState globalState) { + if (globalState.getSchema().getTables(t -> t.isView()).isEmpty()) { + throw new IgnoreMeException(); + } + StringBuilder sb = new StringBuilder("DROP VIEW "); + if (Randomly.getBoolean()) { + sb.append("IF EXISTS "); + } + sb.append(globalState.getSchema().getRandomTableOrBailout(t -> t.isView()).getName()); + return new SQLQueryAdapter(sb.toString(), null, true); + } + +} diff --git a/src/sqlancer/tidb/gen/TiDBHintGenerator.java b/src/sqlancer/tidb/gen/TiDBHintGenerator.java index f72eafe16..f58695bb8 100644 --- a/src/sqlancer/tidb/gen/TiDBHintGenerator.java +++ b/src/sqlancer/tidb/gen/TiDBHintGenerator.java @@ -1,12 +1,13 @@ package sqlancer.tidb.gen; +import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; import sqlancer.IgnoreMeException; import sqlancer.Randomly; import sqlancer.common.schema.TableIndex; -import sqlancer.tidb.TiDBBugs; import sqlancer.tidb.TiDBSchema.TiDBTable; import sqlancer.tidb.ast.TiDBSelect; import sqlancer.tidb.ast.TiDBText; @@ -23,15 +24,23 @@ enum IndexHint { INL_HASH_JOIN, // INL_MERGE_JOIN, // HASH_JOIN, // + READ_FROM_TIKV, // + READ_FROM_TIFLASH, // HASH_AGG, // STREAM_AGG, // USE_INDEX, // IGNORE_INDEX, // AGG_TO_COP, // - // READ_FROM_STORAGE USE_INDEX_MERGE, // NO_INDEX_MERGE, // - USE_TOJA; + USE_TOJA, // + HASH_JOIN_BUILD, // + HASH_JOIN_PROBE, // + MPP_1PHASE_AGG, // + MPP_2PHASE_AGG, // + LIMIT_TO_COP, // + SHUFFLE_JOIN, // + BROADCAST_JOIN } public TiDBHintGenerator(TiDBSelect select, List tables) { @@ -40,13 +49,40 @@ public TiDBHintGenerator(TiDBSelect select, List tables) { } public static void generateHints(TiDBSelect select, List tables) { - new TiDBHintGenerator(select, tables).generate(); + new TiDBHintGenerator(select, tables).randomHint(); + } + public static List generateAllHints(TiDBSelect select, List tables) { + TiDBHintGenerator generator = new TiDBHintGenerator(select, tables); + return generator.allHints(); } - private void generate() { + private void randomHint() { TiDBTable table = Randomly.fromList(tables); - switch (Randomly.fromOptions(IndexHint.values())) { + IndexHint chosenhint = Randomly.fromOptions(IndexHint.values()); + generate(table, chosenhint); + } + + private List allHints() { + List results = new ArrayList<>(); + IndexHint[] values = IndexHint.values(); + List availableHints = new ArrayList<>(Arrays.asList(values)); + + for (IndexHint hint : availableHints) { + try { + TiDBText generatedHint = generate(Randomly.fromList(tables), hint); + results.add(generatedHint); + } catch (IgnoreMeException e) { + continue; + } + } + return results; + } + + private TiDBText generate(TiDBTable table, IndexHint chosenhint) { + sb.setLength(0); + + switch (chosenhint) { case MERGE_JOIN: tablesHint("MERGE_JOIN"); break; @@ -54,9 +90,6 @@ private void generate() { tablesHint("INL_JOIN"); break; case INL_HASH_JOIN: - if (TiDBBugs.bug50) { - throw new IgnoreMeException(); - } tablesHint("INL_HASH_JOIN"); break; case INL_MERGE_JOIN: @@ -65,12 +98,27 @@ private void generate() { case HASH_JOIN: tablesHint("HASH_JOIN"); break; + case READ_FROM_TIKV: + storageHint("READ_FROM_STORAGE(TIKV"); + break; + case READ_FROM_TIFLASH: + storageHint("READ_FROM_STORAGE(TIFLASH"); + break; case HASH_AGG: sb.append("HASH_AGG()"); break; case STREAM_AGG: sb.append("STREAM_AGG()"); break; + case MPP_1PHASE_AGG: + sb.append("MPP_1PHASE_AGG()"); + break; + case MPP_2PHASE_AGG: + sb.append("MPP_2PHASE_AGG()"); + break; + case LIMIT_TO_COP: + sb.append("LIMIT_TO_COP()"); + break; case USE_INDEX: indexesHint("USE_INDEX"); break; @@ -80,16 +128,18 @@ private void generate() { case AGG_TO_COP: sb.append("AGG_TO_COP()"); break; + case SHUFFLE_JOIN: + twoTablesHint("SHUFFLE_JOIN", table); + break; case USE_INDEX_MERGE: - if (table.hasIndexes()) { - sb.append("USE_INDEX_MERGE("); - sb.append(table.getName()); - sb.append(", "); - List indexes = Randomly.nonEmptySubset(table.getIndexes()); - sb.append(indexes.stream().map(i -> i.getIndexName()).collect(Collectors.joining(", "))); - sb.append(")"); + if (Randomly.getBoolean()) { + if (table.hasIndexes()) { + tablesHint("USE_INDEX_MERGE"); + } else { + throw new IgnoreMeException(); + } } else { - throw new IgnoreMeException(); + twoTablesHint("USE_INDEX_MERGE", table); } break; case NO_INDEX_MERGE: @@ -100,10 +150,21 @@ private void generate() { sb.append(Randomly.getBoolean()); sb.append(")"); break; + case HASH_JOIN_BUILD: + tablesHint("HASH_JOIN_BUILD"); + break; + case HASH_JOIN_PROBE: + tablesHint("HASH_JOIN_PROBE"); + break; + case BROADCAST_JOIN: + twoTablesHint("BROADCAST_JOIN", table); + break; default: throw new AssertionError(); } - select.setHint(new TiDBText(sb.toString())); + TiDBText hint = new TiDBText(sb.toString()); + select.setHint(hint); + return hint; } private void indexesHint(String string) { @@ -129,6 +190,27 @@ private void tablesHint(String string) { sb.append(")"); } + private void storageHint(String string) { + sb.append(string); + sb.append("["); + appendTables(); + sb.append("])"); + } + + private void twoTablesHint(String string, TiDBTable table) { + if (table.hasIndexes()) { + sb.append(string); + sb.append("("); + sb.append(table.getName()); + sb.append(", "); + List indexes = Randomly.nonEmptySubset(table.getIndexes()); + sb.append(indexes.stream().map(i -> i.getIndexName()).collect(Collectors.joining(", "))); + sb.append(")"); + } else { + throw new IgnoreMeException(); + } + } + private void appendTables() { List tableSubset = Randomly.nonEmptySubset(tables); sb.append(tableSubset.stream().map(t -> t.getName()).collect(Collectors.joining(", "))); diff --git a/src/sqlancer/tidb/gen/TiDBIndexGenerator.java b/src/sqlancer/tidb/gen/TiDBIndexGenerator.java index 3444d3281..1be2753b6 100644 --- a/src/sqlancer/tidb/gen/TiDBIndexGenerator.java +++ b/src/sqlancer/tidb/gen/TiDBIndexGenerator.java @@ -3,6 +3,7 @@ import java.sql.SQLException; import java.util.List; +import sqlancer.IgnoreMeException; import sqlancer.Randomly; import sqlancer.common.query.ExpectedErrors; import sqlancer.common.query.SQLQueryAdapter; @@ -16,6 +17,9 @@ private TiDBIndexGenerator() { } public static SQLQueryAdapter getQuery(TiDBGlobalState globalState) throws SQLException { + if (globalState.getSchema().getIndexCount() > globalState.getDbmsSpecificOptions().maxNumIndexes) { + throw new IgnoreMeException(); + } ExpectedErrors errors = new ExpectedErrors(); TiDBTable randomTable = globalState.getSchema().getRandomTable(t -> !t.isView()); @@ -57,6 +61,8 @@ public static SQLQueryAdapter getQuery(TiDBGlobalState globalState) throws SQLEx } errors.add("Cannot decode index value, because"); // invalid value for generated column errors.add("index already exist"); + errors.add("Data truncation"); + errors.add("key was too long"); return new SQLQueryAdapter(sb.toString(), errors, true); } diff --git a/src/sqlancer/tidb/gen/TiDBInsertGenerator.java b/src/sqlancer/tidb/gen/TiDBInsertGenerator.java index a78d6c2fb..363fa438b 100644 --- a/src/sqlancer/tidb/gen/TiDBInsertGenerator.java +++ b/src/sqlancer/tidb/gen/TiDBInsertGenerator.java @@ -26,11 +26,15 @@ public TiDBInsertGenerator(TiDBGlobalState globalState) { } public static SQLQueryAdapter getQuery(TiDBGlobalState globalState) throws SQLException { - return new TiDBInsertGenerator(globalState).get(); + TiDBTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + return new TiDBInsertGenerator(globalState).get(table); } - private SQLQueryAdapter get() { - TiDBTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + public static SQLQueryAdapter getQuery(TiDBGlobalState globalState, TiDBTable table) { + return new TiDBInsertGenerator(globalState).get(table); + } + + private SQLQueryAdapter get(TiDBTable table) { gen = new TiDBExpressionGenerator(globalState).setColumns(table.getColumns()); StringBuilder sb = new StringBuilder(); boolean isInsert = Randomly.getBoolean(); @@ -75,18 +79,14 @@ private void insertColumns(StringBuilder sb, List columns) { sb.append(", "); } sb.append("("); - for (int nrColumn = 0; nrColumn < columns.size(); nrColumn++) { - if (nrColumn != 0) { + int i = 0; + for (TiDBColumn c : columns) { + if (i++ != 0) { sb.append(", "); } - insertValue(sb); + sb.append(TiDBVisitor.asString(gen.generateConstant(c.getType().getPrimitiveDataType()))); } sb.append(")"); } } - - private void insertValue(StringBuilder sb) { - sb.append(gen.generateConstant()); // TODO: try to insert valid data - } - } diff --git a/src/sqlancer/tidb/gen/TiDBRandomQuerySynthesizer.java b/src/sqlancer/tidb/gen/TiDBRandomQuerySynthesizer.java index f0609ad56..1b1599ac2 100644 --- a/src/sqlancer/tidb/gen/TiDBRandomQuerySynthesizer.java +++ b/src/sqlancer/tidb/gen/TiDBRandomQuerySynthesizer.java @@ -41,7 +41,7 @@ public static TiDBSelect generateSelect(TiDBGlobalState globalState, int nrColum select.setWhereClause(gen.generateExpression()); } if (Randomly.getBooleanWithRatherLowProbability()) { - select.setOrderByExpressions(gen.generateOrderBys()); + select.setOrderByClauses(gen.generateOrderBys()); } if (Randomly.getBoolean()) { select.setGroupByExpressions(gen.generateExpressions(Randomly.smallNumber() + 1)); diff --git a/src/sqlancer/tidb/gen/TiDBSetGenerator.java b/src/sqlancer/tidb/gen/TiDBSetGenerator.java index b48ea31cc..c997f6717 100644 --- a/src/sqlancer/tidb/gen/TiDBSetGenerator.java +++ b/src/sqlancer/tidb/gen/TiDBSetGenerator.java @@ -43,9 +43,11 @@ private enum Action { TIDB_ENABLE_WINDOW_FUNCTION("tidb_enable_window_function", (r) -> Randomly.fromOptions(0, 1)), - TIDB_ENABLE_FAST_ANALYZE("tidb_enable_fast_analyze", (r) -> Randomly.fromOptions(0, 1)), // + // TIDB_ENABLE_FAST_ANALYZE("tidb_enable_fast_analyze", (r) -> Randomly.fromOptions(0, 1)), // + // java.sql.SQLException: Fast analyze hasn't reached General Availability and only support analyze version 1 + // currently TIDB_WAIT_SPLIT_REGION_FINISH("tidb_wait_split_region_finish", (r) -> Randomly.fromOptions(0, 1)), - TIDB_SCATTER_REGION("global.tidb_scatter_region", (r) -> Randomly.fromOptions(0, 1)), + TIDB_SCATTER_REGION("global.tidb_scatter_region", (r) -> Randomly.fromOptions("``", "`table`", "global")), TIDB_ENABLE_STMT_SUMMARY("global.tidb_enable_stmt_summary", (r) -> Randomly.fromOptions(0, 1)), // TIDB_ENABLE_CHUNK_RPC("tidb_enable_chunk_rpc", (r) -> Randomly.fromOptions(0, 1)); diff --git a/src/sqlancer/tidb/gen/TiDBTableGenerator.java b/src/sqlancer/tidb/gen/TiDBTableGenerator.java index 25d2cbb99..451681409 100644 --- a/src/sqlancer/tidb/gen/TiDBTableGenerator.java +++ b/src/sqlancer/tidb/gen/TiDBTableGenerator.java @@ -5,10 +5,10 @@ import java.util.List; import java.util.stream.Collectors; +import sqlancer.IgnoreMeException; import sqlancer.Randomly; import sqlancer.common.query.ExpectedErrors; import sqlancer.common.query.SQLQueryAdapter; -import sqlancer.tidb.TiDBBugs; import sqlancer.tidb.TiDBExpressionGenerator; import sqlancer.tidb.TiDBProvider.TiDBGlobalState; import sqlancer.tidb.TiDBSchema.TiDBColumn; @@ -24,14 +24,22 @@ public class TiDBTableGenerator { private boolean primaryKeyAsTableConstraints; private final ExpectedErrors errors = new ExpectedErrors(); + public static SQLQueryAdapter createRandomTableStatement(TiDBGlobalState globalState) throws SQLException { + if (globalState.getSchema().getDatabaseTables().size() > globalState.getDbmsSpecificOptions().maxNumTables) { + throw new IgnoreMeException(); + } + return new TiDBTableGenerator().getQuery(globalState); + } + public SQLQueryAdapter getQuery(TiDBGlobalState globalState) throws SQLException { errors.add("Information schema is changed during the execution of the statement"); + errors.add("A CLUSTERED INDEX must include all columns in the table's partitioning function"); String tableName = globalState.getSchema().getFreeTableName(); int nrColumns = Randomly.smallNumber() + 1; allowPrimaryKey = Randomly.getBoolean(); primaryKeyAsTableConstraints = allowPrimaryKey && Randomly.getBoolean(); for (int i = 0; i < nrColumns; i++) { - TiDBColumn fakeColumn = new TiDBColumn("c" + i, null, false, false); + TiDBColumn fakeColumn = new TiDBColumn("c" + i, null, false, false, false); columns.add(fakeColumn); } TiDBExpressionGenerator gen = new TiDBExpressionGenerator(globalState).setColumns(columns); @@ -39,7 +47,7 @@ public SQLQueryAdapter getQuery(TiDBGlobalState globalState) throws SQLException StringBuilder sb = new StringBuilder("CREATE TABLE "); sb.append(tableName); - if (Randomly.getBoolean() && globalState.getSchema().getDatabaseTables().size() > 0) { + if (Randomly.getBoolean() && !globalState.getSchema().getDatabaseTables().isEmpty()) { sb.append(" LIKE "); TiDBTable otherTable = globalState.getSchema().getRandomTable(); sb.append(otherTable.getName()); @@ -113,8 +121,7 @@ && canUseAsUnique(type) && !isGeneratedColumn) { errors.add(" used in key specification without a key length"); } sb.append(")"); - if (Randomly.getBooleanWithRatherLowProbability() - && !TiDBBugs.bug14 /* there are also a number of unresolved other partitioning bugs */) { + if (Randomly.getBooleanWithRatherLowProbability()) { sb.append("PARTITION BY HASH("); sb.append(TiDBVisitor.asString(gen.generateExpression())); sb.append(") "); @@ -127,9 +134,6 @@ && canUseAsUnique(type) && !isGeneratedColumn) { errors.add("A UNIQUE INDEX must include all columns in the table's partitioning function"); errors.add("is of a not allowed type for this type of partitioning"); errors.add("The PARTITION function returns the wrong type"); - if (TiDBBugs.bug16) { - errors.add("UnknownType: *ast.WhenClause"); - } } } @@ -144,11 +148,10 @@ private void appendType(StringBuilder sb, TiDBCompositeDataType type) { } private void appendSizeSpecifiers(StringBuilder sb, TiDBDataType type) { - if (type.isNumeric() && Randomly.getBoolean() && !TiDBBugs.bug16028) { + if (type.isNumeric() && Randomly.getBoolean()) { sb.append(" UNSIGNED"); } - if (type.isNumeric() && Randomly.getBoolean() - && !TiDBBugs.bug16028 /* seems to be the same bug as https://github.com/pingcap/tidb/issues/16028 */) { + if (type.isNumeric() && Randomly.getBoolean()) { sb.append(" ZEROFILL"); } } diff --git a/src/sqlancer/tidb/gen/TiDBUpdateGenerator.java b/src/sqlancer/tidb/gen/TiDBUpdateGenerator.java index 4912899fe..241ee3321 100644 --- a/src/sqlancer/tidb/gen/TiDBUpdateGenerator.java +++ b/src/sqlancer/tidb/gen/TiDBUpdateGenerator.java @@ -4,7 +4,7 @@ import java.util.List; import sqlancer.Randomly; -import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.gen.AbstractUpdateGenerator; import sqlancer.common.query.SQLQueryAdapter; import sqlancer.tidb.TiDBErrors; import sqlancer.tidb.TiDBExpressionGenerator; @@ -13,32 +13,27 @@ import sqlancer.tidb.TiDBSchema.TiDBTable; import sqlancer.tidb.visitor.TiDBVisitor; -public final class TiDBUpdateGenerator { +public final class TiDBUpdateGenerator extends AbstractUpdateGenerator { - private TiDBUpdateGenerator() { + private final TiDBGlobalState globalState; + private TiDBExpressionGenerator gen; + + private TiDBUpdateGenerator(TiDBGlobalState globalState) { + this.globalState = globalState; } public static SQLQueryAdapter getQuery(TiDBGlobalState globalState) throws SQLException { - ExpectedErrors errors = new ExpectedErrors(); + return new TiDBUpdateGenerator(globalState).generate(); + } + + private SQLQueryAdapter generate() throws SQLException { TiDBTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); - TiDBExpressionGenerator gen = new TiDBExpressionGenerator(globalState).setColumns(table.getColumns()); - StringBuilder sb = new StringBuilder("UPDATE "); + List columns = table.getRandomNonEmptyColumnSubset(); + gen = new TiDBExpressionGenerator(globalState).setColumns(table.getColumns()); + sb.append("UPDATE "); sb.append(table.getName()); sb.append(" SET "); - List columns = table.getRandomNonEmptyColumnSubset(); - for (int i = 0; i < columns.size(); i++) { - if (i != 0) { - sb.append(", "); - } - sb.append(columns.get(i).getName()); - sb.append("="); - if (Randomly.getBoolean()) { - sb.append(gen.generateConstant()); - } else { - sb.append(TiDBVisitor.asString(gen.generateExpression())); - TiDBErrors.addExpressionErrors(errors); - } - } + updateColumns(columns); if (Randomly.getBoolean()) { sb.append(" WHERE "); TiDBErrors.addExpressionErrors(errors); @@ -49,4 +44,14 @@ public static SQLQueryAdapter getQuery(TiDBGlobalState globalState) throws SQLEx return new SQLQueryAdapter(sb.toString(), errors); } + @Override + protected void updateValue(TiDBColumn column) { + if (Randomly.getBoolean()) { + sb.append(gen.generateConstant()); + } else { + sb.append(TiDBVisitor.asString(gen.generateExpression())); + TiDBErrors.addExpressionErrors(errors); + } + } + } diff --git a/src/sqlancer/tidb/gen/TiDBViewGenerator.java b/src/sqlancer/tidb/gen/TiDBViewGenerator.java index 9da8e433e..79c284ecd 100644 --- a/src/sqlancer/tidb/gen/TiDBViewGenerator.java +++ b/src/sqlancer/tidb/gen/TiDBViewGenerator.java @@ -4,8 +4,10 @@ import sqlancer.Randomly; import sqlancer.common.query.ExpectedErrors; import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.tidb.TiDBBugs; import sqlancer.tidb.TiDBErrors; import sqlancer.tidb.TiDBProvider.TiDBGlobalState; +import sqlancer.tidb.ast.TiDBSelect; public final class TiDBViewGenerator { @@ -13,6 +15,9 @@ private TiDBViewGenerator() { } public static SQLQueryAdapter getQuery(TiDBGlobalState globalState) { + if (globalState.getSchema().getDatabaseTables().size() > globalState.getDbmsSpecificOptions().maxNumTables) { + throw new IgnoreMeException(); + } int nrColumns = Randomly.smallNumber() + 1; StringBuilder sb = new StringBuilder("CREATE "); if (Randomly.getBoolean()) { @@ -34,7 +39,11 @@ public static SQLQueryAdapter getQuery(TiDBGlobalState globalState) { sb.append(i); } sb.append(") AS "); - sb.append(TiDBRandomQuerySynthesizer.generate(globalState, nrColumns).getQueryString()); + TiDBSelect select = TiDBRandomQuerySynthesizer.generateSelect(globalState, nrColumns); + if (TiDBBugs.bug38319 && !select.getGroupByExpressions().isEmpty()) { + throw new IgnoreMeException(); + } + sb.append(select.asString()); ExpectedErrors errors = new ExpectedErrors(); TiDBErrors.addExpressionErrors(errors); errors.add( diff --git a/src/sqlancer/tidb/oracle/TiDBDQPOracle.java b/src/sqlancer/tidb/oracle/TiDBDQPOracle.java new file mode 100644 index 000000000..57ee4409f --- /dev/null +++ b/src/sqlancer/tidb/oracle/TiDBDQPOracle.java @@ -0,0 +1,80 @@ +package sqlancer.tidb.oracle; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.ComparatorHelper; +import sqlancer.Randomly; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.tidb.TiDBErrors; +import sqlancer.tidb.TiDBExpressionGenerator; +import sqlancer.tidb.TiDBProvider.TiDBGlobalState; +import sqlancer.tidb.TiDBSchema.TiDBTables; +import sqlancer.tidb.ast.TiDBColumnReference; +import sqlancer.tidb.ast.TiDBExpression; +import sqlancer.tidb.ast.TiDBJoin; +import sqlancer.tidb.ast.TiDBSelect; +import sqlancer.tidb.ast.TiDBTableReference; +import sqlancer.tidb.ast.TiDBText; +import sqlancer.tidb.gen.TiDBHintGenerator; +import sqlancer.tidb.visitor.TiDBVisitor; + +public class TiDBDQPOracle implements TestOracle { + private TiDBExpressionGenerator gen; + private final TiDBGlobalState state; + private TiDBSelect select; + private final ExpectedErrors errors = new ExpectedErrors(); + + public TiDBDQPOracle(TiDBGlobalState globalState) { + state = globalState; + TiDBErrors.addExpressionErrors(errors); + } + + @Override + public void check() throws SQLException { + // Randomly generate a query + TiDBTables tables = state.getSchema().getRandomTableNonEmptyTables(); + gen = new TiDBExpressionGenerator(state).setColumns(tables.getColumns()); + select = new TiDBSelect(); + + List fetchColumns = new ArrayList<>(); + fetchColumns.addAll(Randomly.nonEmptySubset(tables.getColumns()).stream().map(c -> new TiDBColumnReference(c)) + .collect(Collectors.toList())); + select.setFetchColumns(fetchColumns); + + List tableList = tables.getTables().stream().map(t -> new TiDBTableReference(t)) + .collect(Collectors.toList()); + List joins = TiDBJoin.getJoins(tableList, state).stream().collect(Collectors.toList()); + select.setJoinList(joins); + select.setFromList(tableList); + if (Randomly.getBoolean()) { + select.setWhereClause(gen.generateExpression()); + } + if (Randomly.getBooleanWithRatherLowProbability()) { + select.setOrderByClauses(gen.generateOrderBys()); + } + if (Randomly.getBoolean()) { + select.setLimitClause(gen.generateExpression()); + } + if (Randomly.getBoolean()) { + select.setOffsetClause(gen.generateExpression()); + } + + String originalQueryString = TiDBVisitor.asString(select); + List originalResult = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, + state); + + List hintList = TiDBHintGenerator.generateAllHints(select, tables.getTables()); + for (TiDBText hint : hintList) { + select.setHint(hint); + String queryString = TiDBVisitor.asString(select); + List result = ComparatorHelper.getResultSetFirstColumnAsString(queryString, errors, state); + ComparatorHelper.assumeResultSetsAreEqual(originalResult, result, originalQueryString, List.of(queryString), + state); + } + } + +} diff --git a/src/sqlancer/tidb/oracle/TiDBTLPBase.java b/src/sqlancer/tidb/oracle/TiDBTLPBase.java index ba2ec9df0..d23e2c4c0 100644 --- a/src/sqlancer/tidb/oracle/TiDBTLPBase.java +++ b/src/sqlancer/tidb/oracle/TiDBTLPBase.java @@ -23,7 +23,7 @@ import sqlancer.tidb.gen.TiDBHintGenerator; public abstract class TiDBTLPBase extends TernaryLogicPartitioningOracleBase - implements TestOracle { + implements TestOracle { TiDBSchema s; TiDBTables targetTables; @@ -50,7 +50,7 @@ public void check() throws SQLException { List tableList = tables.stream().map(t -> new TiDBTableReference(t)) .collect(Collectors.toList()); - List joins = TiDBJoin.getJoins(tableList, state); + List joins = TiDBJoin.getJoins(tableList, state).stream().collect(Collectors.toList()); select.setJoinList(joins); select.setFromList(tableList); select.setWhereClause(null); diff --git a/src/sqlancer/tidb/oracle/TiDBTLPHavingOracle.java b/src/sqlancer/tidb/oracle/TiDBTLPHavingOracle.java index a9f64671b..76e1a3c46 100644 --- a/src/sqlancer/tidb/oracle/TiDBTLPHavingOracle.java +++ b/src/sqlancer/tidb/oracle/TiDBTLPHavingOracle.java @@ -12,7 +12,9 @@ import sqlancer.tidb.ast.TiDBExpression; import sqlancer.tidb.visitor.TiDBVisitor; -public class TiDBTLPHavingOracle extends TiDBTLPBase implements TestOracle { +public class TiDBTLPHavingOracle extends TiDBTLPBase implements TestOracle { + + private String generatedQueryString; public TiDBTLPHavingOracle(TiDBGlobalState state) { super(state); @@ -27,11 +29,12 @@ public void check() throws SQLException { } boolean orderBy = Randomly.getBoolean(); if (orderBy) { - select.setOrderByExpressions(gen.generateOrderBys()); + select.setOrderByClauses(gen.generateOrderBys()); } select.setGroupByExpressions(gen.generateExpressions(Randomly.smallNumber() + 1)); select.setHavingClause(null); String originalQueryString = TiDBVisitor.asString(select); + generatedQueryString = originalQueryString; List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); select.setHavingClause(predicate); @@ -51,4 +54,9 @@ public void check() throws SQLException { protected TiDBExpression generatePredicate() { return gen.generateHavingClause(); } + + @Override + public String getLastQueryString() { + return generatedQueryString; + } } diff --git a/src/sqlancer/tidb/oracle/TiDBTLPWhereOracle.java b/src/sqlancer/tidb/oracle/TiDBTLPWhereOracle.java deleted file mode 100644 index 4c8496e94..000000000 --- a/src/sqlancer/tidb/oracle/TiDBTLPWhereOracle.java +++ /dev/null @@ -1,45 +0,0 @@ -package sqlancer.tidb.oracle; - -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.List; - -import sqlancer.ComparatorHelper; -import sqlancer.Randomly; -import sqlancer.tidb.TiDBErrors; -import sqlancer.tidb.TiDBProvider.TiDBGlobalState; -import sqlancer.tidb.visitor.TiDBVisitor; - -public class TiDBTLPWhereOracle extends TiDBTLPBase { - - public TiDBTLPWhereOracle(TiDBGlobalState state) { - super(state); - TiDBErrors.addExpressionErrors(errors); - } - - @Override - public void check() throws SQLException { - super.check(); - select.setWhereClause(null); - String originalQueryString = TiDBVisitor.asString(select); - - List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); - - boolean orderBy = Randomly.getBooleanWithRatherLowProbability(); - if (orderBy) { - select.setOrderByExpressions(gen.generateOrderBys()); - } - select.setWhereClause(predicate); - String firstQueryString = TiDBVisitor.asString(select); - select.setWhereClause(negatedPredicate); - String secondQueryString = TiDBVisitor.asString(select); - select.setWhereClause(isNullPredicate); - String thirdQueryString = TiDBVisitor.asString(select); - List combinedString = new ArrayList<>(); - List secondResultSet = ComparatorHelper.getCombinedResultSet(firstQueryString, secondQueryString, - thirdQueryString, combinedString, !orderBy, state, errors); - ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, - state); - } - -} diff --git a/src/sqlancer/tidb/visitor/TiDBToStringVisitor.java b/src/sqlancer/tidb/visitor/TiDBToStringVisitor.java index 5e0c56841..67c7cda31 100644 --- a/src/sqlancer/tidb/visitor/TiDBToStringVisitor.java +++ b/src/sqlancer/tidb/visitor/TiDBToStringVisitor.java @@ -1,9 +1,7 @@ package sqlancer.tidb.visitor; -import sqlancer.IgnoreMeException; import sqlancer.Randomly; import sqlancer.common.visitor.ToStringVisitor; -import sqlancer.tidb.TiDBBugs; import sqlancer.tidb.ast.TiDBAggregate; import sqlancer.tidb.ast.TiDBCase; import sqlancer.tidb.ast.TiDBCastOperation; @@ -76,9 +74,9 @@ public void visit(TiDBSelect select) { sb.append(" HAVING "); visit(select.getHavingClause()); } - if (!select.getOrderByExpressions().isEmpty()) { + if (!select.getOrderByClauses().isEmpty()) { sb.append(" ORDER BY "); - visit(select.getOrderByExpressions()); + visit(select.getOrderByClauses()); } } @@ -97,11 +95,7 @@ public void visit(TiDBJoin join) { sb.append(" "); switch (join.getJoinType()) { case INNER: - if (Randomly.getBoolean()) { - sb.append("INNER "); - } else { - sb.append("CROSS "); - } + sb.append("INNER "); sb.append("JOIN "); break; case LEFT: @@ -130,9 +124,6 @@ public void visit(TiDBJoin join) { sb.append("LEFT "); break; case RIGHT: - if (TiDBBugs.bug15844) { - throw new IgnoreMeException(); - } sb.append("RIGHT "); break; default: @@ -140,13 +131,15 @@ public void visit(TiDBJoin join) { } sb.append("JOIN "); break; + case CROSS: + sb.append("CROSS JOIN "); + break; default: throw new AssertionError(); } visit(join.getRightTable()); - sb.append(" "); - if (join.getJoinType() != JoinType.NATURAL) { - sb.append("ON "); + if (join.getOnCondition() != null && join.getJoinType() != JoinType.NATURAL) { + sb.append(" ON "); visit(join.getOnCondition()); } } diff --git a/src/sqlancer/transformations/JSQLParserBasedTransformation.java b/src/sqlancer/transformations/JSQLParserBasedTransformation.java new file mode 100644 index 000000000..e7b7b6145 --- /dev/null +++ b/src/sqlancer/transformations/JSQLParserBasedTransformation.java @@ -0,0 +1,36 @@ +package sqlancer.transformations; + +import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import net.sf.jsqlparser.statement.Statement; + +/** + * Transformations based on JSQLParser should be derived from this class. + */ + +public class JSQLParserBasedTransformation extends Transformation { + + protected Statement statement; + + public JSQLParserBasedTransformation(String desc) { + super(desc); + } + + @Override + protected void onStatementChanged() { + if (statementChangedHandler != null) { + statementChangedHandler.accept(this.statement.toString()); + } + } + + @Override + public boolean init(String sql) { + this.current = sql; + try { + statement = CCJSqlParserUtil.parse(current); + } catch (Exception e) { + return false; + } + return true; + } + +} diff --git a/src/sqlancer/transformations/RemoveClausesOfSelect.java b/src/sqlancer/transformations/RemoveClausesOfSelect.java new file mode 100644 index 000000000..832f82762 --- /dev/null +++ b/src/sqlancer/transformations/RemoveClausesOfSelect.java @@ -0,0 +1,102 @@ +package sqlancer.transformations; + +import java.util.List; + +import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.statement.select.Distinct; +import net.sf.jsqlparser.statement.select.GroupByElement; +import net.sf.jsqlparser.statement.select.Limit; +import net.sf.jsqlparser.statement.select.Offset; +import net.sf.jsqlparser.statement.select.PlainSelect; +import net.sf.jsqlparser.statement.select.Select; +import net.sf.jsqlparser.statement.select.SubSelect; +import net.sf.jsqlparser.statement.select.WithItem; +import net.sf.jsqlparser.util.deparser.ExpressionDeParser; +import net.sf.jsqlparser.util.deparser.SelectDeParser; + +/** + * remove clauses of a select, such as join, where, group by, distinct, offset, limit. + * + * e.g. select * from t where a = b offset 1 limit 1 -> select * from t; + */ + +public class RemoveClausesOfSelect extends JSQLParserBasedTransformation { + private final SelectDeParser remover = new SelectDeParser() { + @Override + public void visit(PlainSelect plainSelect) { + handleSelect(plainSelect); + super.visit(plainSelect); + } + }; + + public RemoveClausesOfSelect() { + super("remove clauses of select"); + } + + @Override + public boolean init(String original) { + + boolean baseSuc = super.init(original); + if (!baseSuc) { + return false; + } + + this.remover.setExpressionVisitor(new ExpressionDeParser(remover, new StringBuilder())); + return true; + } + + @Override + public void apply() { + super.apply(); + if (statement instanceof Select) { + Select select = (Select) statement; + select.getSelectBody().accept(remover); + + List withItemsList = select.getWithItemsList(); + if (withItemsList == null) { + return; + } + tryRemoveElms(select, withItemsList, Select::setWithItemsList); + + for (WithItem withItem : withItemsList) { + SubSelect subSelect = withItem.getSubSelect(); + if (subSelect == null) { + return; + } + + if (subSelect.getSelectBody() != null) { + subSelect.getSelectBody().accept(remover); + } + } + } + } + + private void handleSelect(PlainSelect plainSelect) { + + Expression where = plainSelect.getWhere(); + if (where != null) { + tryRemove(plainSelect, where, PlainSelect::setWhere); + } + + GroupByElement groupByElement = plainSelect.getGroupBy(); + if (groupByElement != null) { + tryRemove(plainSelect, groupByElement, PlainSelect::setGroupByElement); + } + + Distinct distinct = plainSelect.getDistinct(); + if (distinct != null) { + tryRemove(plainSelect, distinct, PlainSelect::setDistinct); + } + + Offset offset = plainSelect.getOffset(); + if (offset != null) { + tryRemove(plainSelect, offset, PlainSelect::setOffset); + } + + Limit limit = plainSelect.getLimit(); + if (offset != null) { + tryRemove(plainSelect, limit, PlainSelect::setLimit); + } + } + +} diff --git a/src/sqlancer/transformations/RemoveColumnsOfSelect.java b/src/sqlancer/transformations/RemoveColumnsOfSelect.java new file mode 100644 index 000000000..6cdd44d15 --- /dev/null +++ b/src/sqlancer/transformations/RemoveColumnsOfSelect.java @@ -0,0 +1,64 @@ +package sqlancer.transformations; + +import java.util.List; + +import net.sf.jsqlparser.statement.select.PlainSelect; +import net.sf.jsqlparser.statement.select.Select; +import net.sf.jsqlparser.statement.select.SubSelect; +import net.sf.jsqlparser.statement.select.WithItem; +import net.sf.jsqlparser.util.deparser.ExpressionDeParser; +import net.sf.jsqlparser.util.deparser.SelectDeParser; + +/** + * remove columns of a select: e.g. select a, b, c from t -> select a from t. + */ +public class RemoveColumnsOfSelect extends JSQLParserBasedTransformation { + + private final SelectDeParser remover = new SelectDeParser() { + @Override + public void visit(PlainSelect plainSelect) { + tryRemoveElms(plainSelect, plainSelect.getSelectItems(), PlainSelect::setSelectItems); + super.visit(plainSelect); + } + }; + + public RemoveColumnsOfSelect() { + super("remove columns of a select"); + } + + @Override + public boolean init(String original) { + + boolean baseSucc = super.init(original); + if (!baseSucc) { + return false; + } + this.remover.setExpressionVisitor(new ExpressionDeParser(remover, new StringBuilder())); + return true; + } + + @Override + public void apply() { + super.apply(); + if (statement instanceof Select) { + Select select = (Select) statement; + select.getSelectBody().accept(remover); + + List withItemsList = select.getWithItemsList(); + if (withItemsList == null) { + return; + } + for (WithItem withItem : withItemsList) { + SubSelect subSelect = withItem.getSubSelect(); + if (subSelect == null) { + return; + } + + if (subSelect.getSelectBody() != null) { + subSelect.getSelectBody().accept(remover); + } + } + + } + } +} diff --git a/src/sqlancer/transformations/RemoveElementsOfExpressionList.java b/src/sqlancer/transformations/RemoveElementsOfExpressionList.java new file mode 100644 index 000000000..ebeb0b0af --- /dev/null +++ b/src/sqlancer/transformations/RemoveElementsOfExpressionList.java @@ -0,0 +1,87 @@ +package sqlancer.transformations; + +import java.util.List; + +import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.expression.operators.relational.ExpressionList; +import net.sf.jsqlparser.statement.select.GroupByElement; +import net.sf.jsqlparser.statement.select.Join; +import net.sf.jsqlparser.statement.select.PlainSelect; +import net.sf.jsqlparser.statement.select.Select; +import net.sf.jsqlparser.util.deparser.ExpressionDeParser; +import net.sf.jsqlparser.util.deparser.InsertDeParser; +import net.sf.jsqlparser.util.deparser.SelectDeParser; + +/** + * remove elements of an expression list. + * + * NOTE: this only works for select statements and targets at ExpressionList type in JSQLParser, such as groupBy list + */ +public class RemoveElementsOfExpressionList extends JSQLParserBasedTransformation { + private final ExpressionDeParser expressionHandler = new ExpressionDeParser(); + private final SelectDeParser simplifier = new SelectDeParser() { + @Override + public void visit(PlainSelect plainSelect) { + handleSelect(plainSelect); + super.visit(plainSelect); + } + + @Override + public void visit(ExpressionList expressionList) { + List expressions = expressionList.getExpressions(); + tryRemoveElms(expressionList, expressions, ExpressionList::setExpressions); + super.visit(expressionList); + } + }; + private final InsertDeParser insertDeParser = new InsertDeParser() { + @Override + public void visit(ExpressionList expressionList) { + List expressions = expressionList.getExpressions(); + tryRemoveElms(expressionList, expressions, ExpressionList::setExpressions); + super.visit(expressionList); + } + }; + + public RemoveElementsOfExpressionList() { + super("remove elements of expression lists"); + } + + @Override + public boolean init(String sql) { + boolean baseSuc = super.init(sql); + if (!baseSuc) { + return false; + } + this.simplifier.setExpressionVisitor(expressionHandler); + this.expressionHandler.setSelectVisitor(simplifier); + + this.insertDeParser.setExpressionVisitor(expressionHandler); + this.insertDeParser.setSelectVisitor(simplifier); + return true; + } + + @Override + public void apply() { + super.apply(); + if (statement instanceof Select) { + Select select = (Select) statement; + select.getSelectBody().accept(simplifier); + } + } + + private void handleSelect(PlainSelect plainSelect) { + + GroupByElement groupByElement = plainSelect.getGroupBy(); + + if (groupByElement != null && groupByElement.getGroupByExpressionList() != null) { + ExpressionList expressionList = groupByElement.getGroupByExpressionList(); + List list = expressionList.getExpressions(); + tryRemoveElms(expressionList, list, ExpressionList::setExpressions); + } + + List expressionList = plainSelect.getJoins(); + if (expressionList != null) { + tryRemoveElms(plainSelect, expressionList, PlainSelect::setJoins); + } + } +} diff --git a/src/sqlancer/transformations/RemoveRowsOfInsert.java b/src/sqlancer/transformations/RemoveRowsOfInsert.java new file mode 100644 index 000000000..b49f80149 --- /dev/null +++ b/src/sqlancer/transformations/RemoveRowsOfInsert.java @@ -0,0 +1,45 @@ +package sqlancer.transformations; + +import net.sf.jsqlparser.expression.operators.relational.ExpressionList; +import net.sf.jsqlparser.expression.operators.relational.ItemsList; +import net.sf.jsqlparser.statement.insert.Insert; +import net.sf.jsqlparser.statement.select.SelectBody; +import net.sf.jsqlparser.statement.select.SetOperationList; +import net.sf.jsqlparser.statement.values.ValuesStatement; + +/** + * This Transformer remove rows of insert. Given a sql statement: + * + * INSERT INTO t1(c2, c0) VALUES (1508438260, 2929), (1508438260, TIMESTAMP '1969-12-26 01:57:21'), (0.5347171705591047, + * 398662142); -> INSERT INTO t1 (c2, c0) VALUES (0.5347171705591047, 398662142); + */ +public class RemoveRowsOfInsert extends JSQLParserBasedTransformation { + public RemoveRowsOfInsert() { + super("remove rows of an insert statement"); + } + + @Override + public void apply() { + super.apply(); + if (!(statement instanceof Insert)) { + return; + } + SelectBody selectBody = ((Insert) statement).getSelect().getSelectBody(); + if (!(selectBody instanceof SetOperationList)) { + return; + } + SetOperationList insertingList = (SetOperationList) selectBody; + for (SelectBody selBody : insertingList.getSelects()) { + if (!(selBody instanceof ValuesStatement)) { + continue; + } + ValuesStatement valuesStatement = (ValuesStatement) selBody; + ItemsList itemsList = valuesStatement.getExpressions(); + if (!(itemsList instanceof ExpressionList)) { + continue; + } + tryRemoveElms((ExpressionList) itemsList, ((ExpressionList) itemsList).getExpressions(), + ExpressionList::setExpressions); + } + } +} diff --git a/src/sqlancer/transformations/RemoveUnions.java b/src/sqlancer/transformations/RemoveUnions.java new file mode 100644 index 000000000..4f9593345 --- /dev/null +++ b/src/sqlancer/transformations/RemoveUnions.java @@ -0,0 +1,52 @@ +package sqlancer.transformations; + +import java.util.List; + +import net.sf.jsqlparser.statement.select.Select; +import net.sf.jsqlparser.statement.select.SelectBody; +import net.sf.jsqlparser.statement.select.SetOperationList; +import net.sf.jsqlparser.util.deparser.ExpressionDeParser; +import net.sf.jsqlparser.util.deparser.SelectDeParser; + +/** + * try removing sub selects of a union statement. + * + * e.g. select 1 union select 2 -> select 1 + */ + +public class RemoveUnions extends JSQLParserBasedTransformation { + + private final SelectDeParser remover = new SelectDeParser() { + @Override + public void visit(SetOperationList list) { + List selectBodyList = list.getSelects(); + tryRemoveElms(list, selectBodyList, SetOperationList::setSelects); + super.visit(list); + } + }; + + public RemoveUnions() { + super("remove union selects"); + } + + @Override + public boolean init(String sql) { + + boolean baseSuc = super.init(sql); + if (!baseSuc) { + return false; + } + + this.remover.setExpressionVisitor(new ExpressionDeParser(remover, new StringBuilder())); + return true; + } + + @Override + public void apply() { + super.apply(); + if (statement instanceof Select) { + Select select = (Select) statement; + select.getSelectBody().accept(remover); + } + } +} diff --git a/src/sqlancer/transformations/RoundDoubleConstant.java b/src/sqlancer/transformations/RoundDoubleConstant.java new file mode 100644 index 000000000..d22496c4e --- /dev/null +++ b/src/sqlancer/transformations/RoundDoubleConstant.java @@ -0,0 +1,74 @@ +package sqlancer.transformations; + +import java.text.DecimalFormat; +import java.util.HashSet; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Round double values which are longer than a certain length. e.g. 2.4782565267 -> 2.478. + * + * This transformation is not based on JSQLParser. + */ +public class RoundDoubleConstant extends Transformation { + + private Set doubleValueCollector; + + private String currentString; + + private static final int ROUND_LENGTH = 3; + private DecimalFormat decimalFormat; + + public RoundDoubleConstant() { + super("round double constant values"); + } + + @Override + public boolean init(String sql) { + super.init(sql); + decimalFormat = new DecimalFormat("#." + "#".repeat(ROUND_LENGTH)); + + currentString = sql; + doubleValueCollector = new HashSet<>(); + + String regex = "\\b-?\\d+\\.\\d+\\b"; + + Pattern pattern = Pattern.compile(regex); + Matcher matcher = pattern.matcher(sql); + + while (matcher.find()) { + String matchedText = matcher.group(); + String decimalPart = matchedText.replaceAll("\\d+\\.", ""); + int decimalPlaces = decimalPart.length(); + if (decimalPlaces > ROUND_LENGTH) { + doubleValueCollector.add(matchedText); + } + } + return true; + } + + @Override + public void apply() { + for (String doubleValue : doubleValueCollector) { + + double targetNumber = Double.parseDouble(doubleValue); + String roundedNumberStr = decimalFormat.format(targetNumber); + + String replacement = currentString.replace(doubleValue, roundedNumberStr); + String original = currentString; + + tryReplace(null, original, replacement, (p, r) -> { + currentString = r; + }); + } + super.apply(); + } + + @Override + protected void onStatementChanged() { + if (statementChangedHandler != null) { + statementChangedHandler.accept(currentString); + } + } +} diff --git a/src/sqlancer/transformations/SimplifyConstant.java b/src/sqlancer/transformations/SimplifyConstant.java new file mode 100644 index 000000000..6638dca73 --- /dev/null +++ b/src/sqlancer/transformations/SimplifyConstant.java @@ -0,0 +1,103 @@ +package sqlancer.transformations; + +import java.util.ArrayList; +import java.util.List; + +import net.sf.jsqlparser.expression.DoubleValue; +import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.expression.LongValue; +import net.sf.jsqlparser.expression.StringValue; +import net.sf.jsqlparser.statement.StatementVisitorAdapter; +import net.sf.jsqlparser.statement.insert.Insert; +import net.sf.jsqlparser.statement.select.Select; +import net.sf.jsqlparser.util.deparser.ExpressionDeParser; +import net.sf.jsqlparser.util.deparser.SelectDeParser; + +/** + * Shorten the constant of a statement e.g. "a_very_long_str" -> "_", 12341234->1. + * + * Note: The API of JSQLParser may have some problems with double values: `setValue` can't change the literal value of a + * DoubleValue object. Therefore, double values are handled at RoundDoubleConstant class. + */ +public class SimplifyConstant extends JSQLParserBasedTransformation { + static class ConstantCollector extends ExpressionDeParser { + private final List candidates = new ArrayList<>(); + + @Override + public void visit(DoubleValue doubleValue) { + candidates.add(doubleValue); + super.visit(doubleValue); + } + + @Override + public void visit(LongValue longValue) { + candidates.add(longValue); + super.visit(longValue); + } + + @Override + public void visit(StringValue stringValue) { + candidates.add(stringValue); + super.visit(stringValue); + } + + public List getCandidates() { + return candidates; + } + } + + public SimplifyConstant() { + super("simplify constant expressions"); + } + + @Override + public void apply() { + super.apply(); + ConstantCollector collector = new ConstantCollector(); + StringBuilder buffer = new StringBuilder(); + SelectDeParser collectorDeParser = new SelectDeParser(collector, buffer); + collector.setSelectVisitor(collectorDeParser); + collector.setBuffer(buffer); + + List candidates = collector.getCandidates(); + + StatementVisitorAdapter statementVisitor = new StatementVisitorAdapter() { + @Override + public void visit(Insert insert) { + insert.getSelect().getSelectBody().accept(collectorDeParser); + super.visit(insert); + } + + @Override + public void visit(Select select) { + select.getSelectBody().accept(collectorDeParser); + super.visit(select); + } + }; + + statement.accept(statementVisitor); + + for (Expression e : candidates) { + if (e instanceof LongValue) { + simplify((LongValue) e); + } else if (e instanceof StringValue) { + simplify((StringValue) e); + } + } + } + + private void simplify(LongValue longValue) { + long variant = 0; + if (!longValue.getStringValue().equals(String.valueOf(variant))) { + tryReplace(longValue, longValue.getStringValue(), String.valueOf(variant), LongValue::setStringValue); + } + } + + private void simplify(StringValue stringValue) { + String variant = "_"; + if (!stringValue.getValue().equals(variant)) { + tryReplace(stringValue, stringValue.getValue(), variant, StringValue::setValue); + } + } + +} diff --git a/src/sqlancer/transformations/SimplifyExpressions.java b/src/sqlancer/transformations/SimplifyExpressions.java new file mode 100644 index 000000000..244629b4d --- /dev/null +++ b/src/sqlancer/transformations/SimplifyExpressions.java @@ -0,0 +1,98 @@ +package sqlancer.transformations; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.BiConsumer; + +import net.sf.jsqlparser.expression.BinaryExpression; +import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.expression.Parenthesis; +import net.sf.jsqlparser.statement.select.PlainSelect; +import net.sf.jsqlparser.statement.select.Select; +import net.sf.jsqlparser.util.deparser.ExpressionDeParser; +import net.sf.jsqlparser.util.deparser.SelectDeParser; + +/** + * This transformation simplifies complicated expressions e.g: a + (b + c) -> b. + */ + +public class SimplifyExpressions extends JSQLParserBasedTransformation { + public SimplifyExpressions() { + super("simplify expressions. e.g. a + b -> a"); + } + + @Override + public boolean init(String sql) { + boolean baseSuc = super.init(sql); + if (!baseSuc) { + return false; + } + this.simplifier.setExpressionVisitor(expressionHandler); + this.expressionHandler.setSelectVisitor(simplifier); + return true; + } + + private final ExpressionDeParser expressionHandler = new ExpressionDeParser() { + @Override + protected void visitBinaryExpression(BinaryExpression binaryExpression, String operator) { + + Expression lhs = binaryExpression.getLeftExpression(); + Expression rhs = binaryExpression.getRightExpression(); + + handleExpression(binaryExpression, lhs, BinaryExpression::setLeftExpression); + handleExpression(binaryExpression, rhs, BinaryExpression::setRightExpression); + + super.visitBinaryExpression(binaryExpression, operator); + } + + }; + private final SelectDeParser simplifier = new SelectDeParser() { + + @Override + public void visit(PlainSelect plainSelect) { + handleSelect(plainSelect); + super.visit(plainSelect); + } + }; + + @Override + public void apply() { + super.apply(); + if (statement instanceof Select) { + Select select = (Select) statement; + select.getSelectBody().accept(simplifier); + } + } + + private void handleSelect(PlainSelect plainSelect) { + Expression where = plainSelect.getWhere(); + if (where != null) { + handleExpression(plainSelect, where, PlainSelect::setWhere); + } + Expression having = plainSelect.getHaving(); + if (having != null) { + handleExpression(plainSelect, having, PlainSelect::setHaving); + } + } + + private List flattenExpression(Expression expression) { + if (expression instanceof BinaryExpression) { + BinaryExpression binaryExpression = (BinaryExpression) expression; + return List.of(binaryExpression.getLeftExpression(), binaryExpression.getRightExpression()); + } else if (expression instanceof Parenthesis) { + return List.of(((Parenthesis) expression).getExpression()); + } + return new ArrayList<>(); + } + + private

void handleExpression(P parent, Expression expr, BiConsumer setter) { + + List expressions = flattenExpression(expr); + for (Expression variant : expressions) { + boolean suc = tryReplace(parent, expr, variant, setter); + if (suc) { + break; + } + } + } +} diff --git a/src/sqlancer/transformations/Transformation.java b/src/sqlancer/transformations/Transformation.java new file mode 100644 index 000000000..affa4c281 --- /dev/null +++ b/src/sqlancer/transformations/Transformation.java @@ -0,0 +1,122 @@ +package sqlancer.transformations; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Supplier; + +/** + * The base class of transformations. Defines APIs to remove, replace, remove elements of a list. + */ +public class Transformation { + + private static Supplier bugJudgement; + private static long reduceSteps; + + protected boolean isChanged; + protected String current; + protected String desc = ""; + + protected Consumer statementChangedHandler; + + public Transformation(String desc) { + this.desc = desc; + } + + @SuppressWarnings("unused") + protected Transformation() { + } + + public static void setBugJudgement(Supplier judgement) { + bugJudgement = judgement; + } + + @Override + public String toString() { + return desc; + } + + public boolean init(String sql) { + isChanged = false; + return true; + } + + public boolean tryRemove(P parent, T target, BiConsumer setter) { + setter.accept(parent, null); + onStatementChanged(); + if (!bugStillTriggers()) { + setter.accept(parent, target); + onStatementChanged(); + return false; + } + reduceSteps++; + isChanged = true; + return true; + } + + public boolean tryReplace(P parent, T original, T vari, BiConsumer setter) { + setter.accept(parent, vari); + onStatementChanged(); + if (!bugStillTriggers()) { + setter.accept(parent, original); + onStatementChanged(); + return false; + } + reduceSteps++; + isChanged = true; + return true; + } + + public void tryRemoveElms(P parent, List elms, // NOPMD + BiConsumer> setter) { + boolean observeChange; + do { + observeChange = false; + for (int i = elms.size() - 1; i >= 0; i--) { + List reducedElms = new ArrayList<>(elms); + reducedElms.subList(i, i + 1).clear(); + setter.accept(parent, reducedElms); + onStatementChanged(); + if (bugStillTriggers()) { + elms = reducedElms; + onStatementChanged(); + observeChange = true; + } + } + isChanged |= observeChange; + setter.accept(parent, elms); + reduceSteps++; + onStatementChanged(); + } while (observeChange); + + } + + public boolean bugStillTriggers() { + try { + return Transformation.bugJudgement.get(); + } catch (Exception ignored) { + } + return false; + } + + public void apply() { + isChanged = false; + } + + public boolean changed() { + return isChanged; + } + + public static long getReduceSteps() { + return reduceSteps; + } + + protected void onStatementChanged() { + } + + public void setStatementChangedCallBack(Consumer statementChangedHandler) { + this.statementChangedHandler = statementChangedHandler; + } + +} diff --git a/src/sqlancer/yugabyte/YugabyteBugs.java b/src/sqlancer/yugabyte/YugabyteBugs.java new file mode 100644 index 000000000..713ae6f88 --- /dev/null +++ b/src/sqlancer/yugabyte/YugabyteBugs.java @@ -0,0 +1,14 @@ +package sqlancer.yugabyte; + +public final class YugabyteBugs { + + // https://github.com/yugabyte/yugabyte-db/issues/11357 + public static boolean bug11357 = true; + + // https://github.com/yugabyte/yugabyte-db/issues/14330 + public static boolean bug14330 = true; + + private YugabyteBugs() { + } + +} diff --git a/src/sqlancer/yugabyte/ycql/YCQLErrors.java b/src/sqlancer/yugabyte/ycql/YCQLErrors.java new file mode 100644 index 000000000..010b2111a --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/YCQLErrors.java @@ -0,0 +1,32 @@ +package sqlancer.yugabyte.ycql; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.common.query.ExpectedErrors; + +public final class YCQLErrors { + + private YCQLErrors() { + } + + public static List getExpressionErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("Signature mismatch in call to builtin function"); + errors.add("Qualified name not allowed for column reference"); + errors.add("Datatype Mismatch"); + errors.add("Invalid Datatype"); + errors.add("Invalid CQL Statement"); + errors.add("Invalid SQL Statement"); + errors.add("Order by clause contains invalid expression"); + errors.add("Invalid Function Call"); + + return errors; + } + + public static void addExpressionErrors(ExpectedErrors errors) { + errors.addAll(getExpressionErrors()); + } + +} diff --git a/src/sqlancer/yugabyte/ycql/YCQLOptions.java b/src/sqlancer/yugabyte/ycql/YCQLOptions.java new file mode 100644 index 000000000..9d865aca2 --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/YCQLOptions.java @@ -0,0 +1,36 @@ +package sqlancer.yugabyte.ycql; + +import java.util.Arrays; +import java.util.List; + +import com.beust.jcommander.Parameter; +import com.beust.jcommander.Parameters; + +import sqlancer.DBMSSpecificOptions; + +@Parameters(separators = "=", commandDescription = "YCQL (default port: " + YCQLOptions.DEFAULT_PORT + + ", default host: " + YCQLOptions.DEFAULT_HOST) +public class YCQLOptions implements DBMSSpecificOptions { + + public static final String DEFAULT_HOST = "localhost"; + public static final int DEFAULT_PORT = 9042; + public static final String DEFAULT_DATACENTER = "datacenter1"; + + @Parameter(names = "--max-num-deletes", description = "The maximum number of DELETE statements that are issued for a database", arity = 1) + public int maxNumDeletes = 1; + + @Parameter(names = "--max-num-updates", description = "The maximum number of UPDATE statements that are issued for a database", arity = 1) + public int maxNumUpdates = 5; + + @Parameter(names = "--datacenter", description = "YCQL datacenter, can be found in system.local table", arity = 1) + public String datacenter = DEFAULT_DATACENTER; + + @Parameter(names = "--oracle") + public List oracles = Arrays.asList(YCQLOracleFactory.FUZZER); + + @Override + public List getTestOracleFactory() { + return oracles; + } + +} diff --git a/src/sqlancer/yugabyte/ycql/YCQLOracleFactory.java b/src/sqlancer/yugabyte/ycql/YCQLOracleFactory.java new file mode 100644 index 000000000..beaf43704 --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/YCQLOracleFactory.java @@ -0,0 +1,18 @@ +package sqlancer.yugabyte.ycql; + +import java.sql.SQLException; + +import sqlancer.OracleFactory; +import sqlancer.common.oracle.TestOracle; +import sqlancer.yugabyte.ycql.test.YCQLFuzzer; + +public enum YCQLOracleFactory implements OracleFactory { + FUZZER { + @Override + public TestOracle create(YCQLProvider.YCQLGlobalState globalState) + throws SQLException { + return new YCQLFuzzer(globalState); + } + + } +} diff --git a/src/sqlancer/yugabyte/ycql/YCQLProvider.java b/src/sqlancer/yugabyte/ycql/YCQLProvider.java new file mode 100644 index 000000000..b82f5b0a2 --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/YCQLProvider.java @@ -0,0 +1,170 @@ +package sqlancer.yugabyte.ycql; + +import static sqlancer.yugabyte.ycql.YCQLSchema.getTableNames; +import static sqlancer.yugabyte.ysql.YSQLProvider.DDL_LOCK; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.List; + +import com.google.auto.service.AutoService; + +import sqlancer.AbstractAction; +import sqlancer.DatabaseProvider; +import sqlancer.IgnoreMeException; +import sqlancer.MainOptions; +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.SQLGlobalState; +import sqlancer.SQLProviderAdapter; +import sqlancer.StatementExecutor; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.query.SQLQueryProvider; +import sqlancer.yugabyte.ycql.YCQLProvider.YCQLGlobalState; +import sqlancer.yugabyte.ycql.gen.YCQLAlterTableGenerator; +import sqlancer.yugabyte.ycql.gen.YCQLDeleteGenerator; +import sqlancer.yugabyte.ycql.gen.YCQLIndexGenerator; +import sqlancer.yugabyte.ycql.gen.YCQLInsertGenerator; +import sqlancer.yugabyte.ycql.gen.YCQLRandomQuerySynthesizer; +import sqlancer.yugabyte.ycql.gen.YCQLTableGenerator; +import sqlancer.yugabyte.ycql.gen.YCQLUpdateGenerator; + +@AutoService(DatabaseProvider.class) +public class YCQLProvider extends SQLProviderAdapter { + + public YCQLProvider() { + super(YCQLGlobalState.class, YCQLOptions.class); + } + + public enum Action implements AbstractAction { + + ALTER(YCQLAlterTableGenerator::getQuery), // + INSERT(YCQLInsertGenerator::getQuery), // + CREATE_INDEX(YCQLIndexGenerator::getQuery), // + DELETE(YCQLDeleteGenerator::generate), // + UPDATE(YCQLUpdateGenerator::getQuery), // + EXPLAIN((g) -> { + ExpectedErrors errors = new ExpectedErrors(); + YCQLErrors.addExpressionErrors(errors); + return new SQLQueryAdapter( + "EXPLAIN " + YCQLToStringVisitor + .asString(YCQLRandomQuerySynthesizer.generateSelect(g, Randomly.smallNumber() + 1)), + errors); + }); + + private final SQLQueryProvider sqlQueryProvider; + + Action(SQLQueryProvider sqlQueryProvider) { + this.sqlQueryProvider = sqlQueryProvider; + } + + @Override + public SQLQueryAdapter getQuery(YCQLGlobalState state) throws Exception { + return sqlQueryProvider.getQuery(state); + } + } + + private static int mapActions(YCQLGlobalState globalState, Action a) { + Randomly r = globalState.getRandomly(); + switch (a) { + case ALTER: + return r.getInteger(0, 10); + case INSERT: + return r.getInteger(0, globalState.getOptions().getMaxNumberInserts()); + case CREATE_INDEX: + case UPDATE: + return r.getInteger(0, globalState.getDbmsSpecificOptions().maxNumUpdates + 1); + case EXPLAIN: + return r.getInteger(0, 2); + case DELETE: + return r.getInteger(0, globalState.getDbmsSpecificOptions().maxNumDeletes + 1); + default: + throw new AssertionError(a); + } + } + + public static class YCQLGlobalState extends SQLGlobalState { + + @Override + protected YCQLSchema readSchema() throws SQLException { + return YCQLSchema.fromConnection(getConnection(), getDatabaseName()); + } + + } + + @Override + public void generateDatabase(YCQLGlobalState globalState) throws Exception { + for (int i = 0; i < Randomly.fromOptions(1, 2); i++) { + boolean success; + do { + SQLQueryAdapter qt = new YCQLTableGenerator().getQuery(globalState); + success = globalState.executeStatement(qt); + } while (!success); + } + if (globalState.getSchema().getDatabaseTables().isEmpty()) { + throw new IgnoreMeException(); // TODO + } + StatementExecutor se = new StatementExecutor<>(globalState, Action.values(), + YCQLProvider::mapActions, (q) -> { + if (globalState.getSchema().getDatabaseTables().isEmpty()) { + throw new IgnoreMeException(); + } + }); + se.executeStatements(); + } + + @Override + public SQLConnection createDatabase(YCQLGlobalState globalState) throws SQLException { + try { + Class.forName("com.ing.data.cassandra.jdbc.CassandraDriver"); + } catch (ClassNotFoundException e) { + throw new AssertionError(); + } + + String host = globalState.getOptions().getHost(); + int port = globalState.getOptions().getPort(); + + if (host == null) { + host = YCQLOptions.DEFAULT_HOST; + } + if (port == MainOptions.NO_SET_PORT) { + port = YCQLOptions.DEFAULT_PORT; + } + + final String url = "jdbc:cassandra://%s:%s/%s?localdatacenter=%s"; + final Connection connection = DriverManager.getConnection( + String.format(url, host, port, "system_schema", globalState.getDbmsSpecificOptions().datacenter)); + + synchronized (DDL_LOCK) { + try (Statement stmt = connection.createStatement()) { + try { + stmt.execute("DROP KEYSPACE IF EXISTS " + globalState.getDatabaseName()); + } catch (Exception se) { + // try again + List tableNames = getTableNames( + new SQLConnection(DriverManager.getConnection(String.format(url, host, port, + globalState.getDatabaseName(), globalState.getDbmsSpecificOptions().datacenter))), + globalState.getDatabaseName()); + for (String tableName : tableNames) { + stmt.execute("DROP TABLE " + globalState.getDatabaseName() + "." + tableName); + } + stmt.execute("DROP KEYSPACE IF EXISTS " + globalState.getDatabaseName()); + } + + stmt.execute("CREATE KEYSPACE IF NOT EXISTS " + globalState.getDatabaseName()); + } + } + + return new SQLConnection(DriverManager.getConnection(String.format(url, host, port, + globalState.getDatabaseName(), globalState.getDbmsSpecificOptions().datacenter))); + } + + @Override + public String getDBMSName() { + return "ycql"; + } + +} diff --git a/src/sqlancer/yugabyte/ycql/YCQLSchema.java b/src/sqlancer/yugabyte/ycql/YCQLSchema.java new file mode 100644 index 000000000..736247364 --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/YCQLSchema.java @@ -0,0 +1,263 @@ +package sqlancer.yugabyte.ycql; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.common.DBMSCommon; +import sqlancer.common.schema.AbstractRelationalTable; +import sqlancer.common.schema.AbstractSchema; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.common.schema.AbstractTables; +import sqlancer.common.schema.TableIndex; +import sqlancer.yugabyte.ycql.YCQLProvider.YCQLGlobalState; +import sqlancer.yugabyte.ycql.YCQLSchema.YCQLTable; + +public class YCQLSchema extends AbstractSchema { + + public enum YCQLDataType { + + INT, VARCHAR, BOOLEAN, FLOAT, DATE, TIMESTAMP; + + public static YCQLDataType getRandom() { + return Randomly.fromOptions(values()); + } + + } + + public static class YCQLCompositeDataType { + + private final YCQLDataType dataType; + + private final int size; + + public YCQLCompositeDataType(YCQLDataType dataType, int size) { + this.dataType = dataType; + this.size = size; + } + + public YCQLDataType getPrimitiveDataType() { + return dataType; + } + + public int getSize() { + if (size == -1) { + throw new AssertionError(this); + } + return size; + } + + public static YCQLCompositeDataType getRandom() { + YCQLDataType type = YCQLDataType.getRandom(); + int size = -1; + switch (type) { + case INT: + size = Randomly.fromOptions(1, 2, 4, 8); + break; + case FLOAT: + size = Randomly.fromOptions(4, 8); + break; + case BOOLEAN: + case VARCHAR: + case DATE: + case TIMESTAMP: + size = 0; + break; + default: + throw new AssertionError(type); + } + + return new YCQLCompositeDataType(type, size); + } + + @Override + public String toString() { + switch (getPrimitiveDataType()) { + case INT: + switch (size) { + case 8: + return Randomly.fromOptions("BIGINT"); + case 4: + return Randomly.fromOptions("INTEGER", "INT"); + case 2: + return Randomly.fromOptions("SMALLINT"); + case 1: + return Randomly.fromOptions("TINYINT"); + default: + throw new AssertionError(size); + } + case VARCHAR: + return "VARCHAR"; + case FLOAT: + switch (size) { + case 8: + return Randomly.fromOptions("DOUBLE"); + case 4: + return Randomly.fromOptions("FLOAT"); + default: + throw new AssertionError(size); + } + case BOOLEAN: + return Randomly.fromOptions("BOOLEAN"); + case TIMESTAMP: + return Randomly.fromOptions("TIMESTAMP"); + case DATE: + return Randomly.fromOptions("DATE"); + default: + throw new AssertionError(getPrimitiveDataType()); + } + } + + } + + public static class YCQLColumn extends AbstractTableColumn { + + private final boolean isPrimaryKey; + private final boolean isNullable; + + public YCQLColumn(String name, YCQLCompositeDataType columnType, boolean isPrimaryKey, boolean isNullable) { + super(name, null, columnType); + this.isPrimaryKey = isPrimaryKey; + this.isNullable = isNullable; + } + + @Override + public boolean isPrimaryKey() { + return isPrimaryKey; + } + + public boolean isNullable() { + return isNullable; + } + + } + + public static class YCQLTables extends AbstractTables { + + public YCQLTables(List tables) { + super(tables); + } + + } + + public YCQLSchema(List databaseTables) { + super(databaseTables); + } + + public YCQLTables getRandomTableNonEmptyTables() { + return new YCQLTables(Randomly.nonEmptySubset(getDatabaseTables())); + } + + private static YCQLCompositeDataType getColumnType(String typeString) { + YCQLDataType primitiveType; + int size = -1; + switch (typeString.toUpperCase()) { + case "INT": + case "INTEGER": + primitiveType = YCQLDataType.INT; + size = 4; + break; + case "SMALLINT": + primitiveType = YCQLDataType.INT; + size = 2; + break; + case "BIGINT": + primitiveType = YCQLDataType.INT; + size = 8; + break; + case "TINYINT": + primitiveType = YCQLDataType.INT; + size = 1; + break; + case "VARCHAR": + case "TEXT": + primitiveType = YCQLDataType.VARCHAR; + break; + case "FLOAT": + primitiveType = YCQLDataType.FLOAT; + size = 4; + break; + case "DOUBLE": + primitiveType = YCQLDataType.FLOAT; + size = 8; + break; + case "BOOLEAN": + primitiveType = YCQLDataType.BOOLEAN; + break; + case "DATE": + primitiveType = YCQLDataType.DATE; + break; + case "TIMESTAMP": + primitiveType = YCQLDataType.TIMESTAMP; + break; + default: + throw new AssertionError(); + } + return new YCQLCompositeDataType(primitiveType, size); + } + + public static class YCQLTable extends AbstractRelationalTable { + + public YCQLTable(String tableName, List columns, boolean isView) { + super(tableName, columns, Collections.emptyList(), isView); + } + + } + + public static YCQLSchema fromConnection(SQLConnection con, String databaseName) throws SQLException { + List databaseTables = new ArrayList<>(); + List tableNames = getTableNames(con, databaseName); + for (String tableName : tableNames) { + if (DBMSCommon.matchesIndexName(tableName)) { + continue; + } + List databaseColumns = getTableColumns(con, databaseName, tableName); + boolean isView = matchesViewName(tableName); + YCQLTable t = new YCQLTable(tableName, databaseColumns, isView); + for (YCQLColumn c : databaseColumns) { + c.setTable(t); + } + databaseTables.add(t); + + } + return new YCQLSchema(databaseTables); + } + + public static List getTableNames(SQLConnection con, String databaseName) throws SQLException { + List tableNames = new ArrayList<>(); + try (Statement s = con.createStatement()) { + try (ResultSet rs = s.executeQuery( + String.format("select * from system_schema.tables where keyspace_name = '%s'", databaseName))) { + while (rs.next()) { + tableNames.add(rs.getString("table_name")); + } + } + } + return tableNames; + } + + private static List getTableColumns(SQLConnection con, String databaseName, String tableName) + throws SQLException { + List columns = new ArrayList<>(); + try (Statement s = con.createStatement()) { + try (ResultSet rs = s.executeQuery(String.format( + "select * from system_schema.columns where keyspace_name = '%s' and table_name = '%s'", + databaseName, tableName))) { + while (rs.next()) { + String columnName = rs.getString("column_name"); + String dataType = rs.getString("type"); + boolean isPrimaryKey = rs.getString("kind").contentEquals("partition_key"); + YCQLColumn c = new YCQLColumn(columnName, getColumnType(dataType), isPrimaryKey, true); + columns.add(c); + } + } + } + return columns; + } + +} diff --git a/src/sqlancer/yugabyte/ycql/YCQLToStringVisitor.java b/src/sqlancer/yugabyte/ycql/YCQLToStringVisitor.java new file mode 100644 index 000000000..c9f357fe9 --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/YCQLToStringVisitor.java @@ -0,0 +1,63 @@ +package sqlancer.yugabyte.ycql; + +import sqlancer.common.ast.newast.NewToStringVisitor; +import sqlancer.yugabyte.ycql.ast.YCQLConstant; +import sqlancer.yugabyte.ycql.ast.YCQLExpression; +import sqlancer.yugabyte.ycql.ast.YCQLSelect; + +public class YCQLToStringVisitor extends NewToStringVisitor { + + @Override + public void visitSpecific(YCQLExpression expr) { + if (expr instanceof YCQLConstant) { + visit((YCQLConstant) expr); + } else if (expr instanceof YCQLSelect) { + visit((YCQLSelect) expr); + } else { + throw new AssertionError(expr.getClass()); + } + } + + private void visit(YCQLConstant constant) { + sb.append(constant.toString()); + } + + private void visit(YCQLSelect select) { + sb.append("SELECT "); + if (select.isDistinct()) { + sb.append("DISTINCT "); + } + visit(select.getFetchColumns()); + sb.append(" FROM "); + visit(select.getFromList()); + if (!select.getFromList().isEmpty() && !select.getJoinList().isEmpty()) { + sb.append(", "); + } + if (!select.getJoinList().isEmpty()) { + visit(select.getJoinList()); + } + if (select.getWhereClause() != null) { + sb.append(" WHERE "); + visit(select.getWhereClause()); + } + if (!select.getOrderByClauses().isEmpty()) { + sb.append(" ORDER BY "); + visit(select.getOrderByClauses()); + } + if (select.getLimitClause() != null) { + sb.append(" LIMIT "); + visit(select.getLimitClause()); + } + if (select.getOffsetClause() != null) { + sb.append(" OFFSET "); + visit(select.getOffsetClause()); + } + } + + public static String asString(YCQLExpression expr) { + YCQLToStringVisitor visitor = new YCQLToStringVisitor(); + visitor.visit(expr); + return visitor.get(); + } + +} diff --git a/src/sqlancer/yugabyte/ycql/ast/YCQLBetweenOperation.java b/src/sqlancer/yugabyte/ycql/ast/YCQLBetweenOperation.java new file mode 100644 index 000000000..880007cd8 --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/ast/YCQLBetweenOperation.java @@ -0,0 +1,9 @@ +package sqlancer.yugabyte.ycql.ast; + +import sqlancer.common.ast.newast.NewBetweenOperatorNode; + +public class YCQLBetweenOperation extends NewBetweenOperatorNode implements YCQLExpression { + public YCQLBetweenOperation(YCQLExpression left, YCQLExpression middle, YCQLExpression right, boolean isTrue) { + super(left, middle, right, isTrue); + } +} diff --git a/src/sqlancer/yugabyte/ycql/ast/YCQLBinaryOperation.java b/src/sqlancer/yugabyte/ycql/ast/YCQLBinaryOperation.java new file mode 100644 index 000000000..fac2af681 --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/ast/YCQLBinaryOperation.java @@ -0,0 +1,10 @@ +package sqlancer.yugabyte.ycql.ast; + +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.common.ast.newast.NewBinaryOperatorNode; + +public class YCQLBinaryOperation extends NewBinaryOperatorNode implements YCQLExpression { + public YCQLBinaryOperation(YCQLExpression left, YCQLExpression right, Operator op) { + super(left, right, op); + } +} diff --git a/src/sqlancer/yugabyte/ycql/ast/YCQLColumnReference.java b/src/sqlancer/yugabyte/ycql/ast/YCQLColumnReference.java new file mode 100644 index 000000000..9c81e70a5 --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/ast/YCQLColumnReference.java @@ -0,0 +1,12 @@ +package sqlancer.yugabyte.ycql.ast; + +import sqlancer.common.ast.newast.ColumnReferenceNode; +import sqlancer.yugabyte.ycql.YCQLSchema; + +public class YCQLColumnReference extends ColumnReferenceNode + implements YCQLExpression { + public YCQLColumnReference(YCQLSchema.YCQLColumn column) { + super(column); + } + +} diff --git a/src/sqlancer/yugabyte/ycql/ast/YCQLConstant.java b/src/sqlancer/yugabyte/ycql/ast/YCQLConstant.java new file mode 100644 index 000000000..08f04bcfb --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/ast/YCQLConstant.java @@ -0,0 +1,171 @@ +package sqlancer.yugabyte.ycql.ast; + +import java.sql.Timestamp; +import java.text.SimpleDateFormat; + +public class YCQLConstant implements YCQLExpression { + + private YCQLConstant() { + } + + public static class YCQLNullConstant extends YCQLConstant { + + @Override + public String toString() { + return "NULL"; + } + + } + + public static class YCQLIntConstant extends YCQLConstant { + + private final long value; + + public YCQLIntConstant(long value) { + this.value = value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + + public long getValue() { + return value; + } + + } + + public static class YCQLDoubleConstant extends YCQLConstant { + + private final double value; + + public YCQLDoubleConstant(double value) { + this.value = value; + } + + public double getValue() { + return value; + } + + @Override + public String toString() { + if (value == Double.POSITIVE_INFINITY) { + return "'+Inf'"; + } else if (value == Double.NEGATIVE_INFINITY) { + return "'-Inf'"; + } + return String.valueOf(value); + } + + } + + public static class YCQLTextConstant extends YCQLConstant { + + private final String value; + + public YCQLTextConstant(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + + @Override + public String toString() { + return "'" + value.replace("'", "''") + "'"; + } + + } + + public static class YCQLDateConstant extends YCQLConstant { + + public String textRepr; + + public YCQLDateConstant(long val) { + Timestamp timestamp = new Timestamp(val); + SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd"); + textRepr = dateFormat.format(timestamp); + } + + public String getValue() { + return textRepr; + } + + @Override + public String toString() { + return String.format("'%s'", textRepr); + } + + } + + public static class YCQLTimestampConstant extends YCQLConstant { + + public String textRepr; + + public YCQLTimestampConstant(long val) { + Timestamp timestamp = new Timestamp(val); + SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); + textRepr = dateFormat.format(timestamp); + } + + public String getValue() { + return textRepr; + } + + @Override + public String toString() { + return String.format("'%s'", textRepr); + } + + } + + public static class YCQLBooleanConstant extends YCQLConstant { + + private final boolean value; + + public YCQLBooleanConstant(boolean value) { + this.value = value; + } + + public boolean getValue() { + return value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + + } + + public static YCQLExpression createStringConstant(String text) { + return new YCQLTextConstant(text); + } + + public static YCQLExpression createFloatConstant(double val) { + return new YCQLDoubleConstant(val); + } + + public static YCQLExpression createIntConstant(long val) { + return new YCQLIntConstant(val); + } + + public static YCQLExpression createNullConstant() { + return new YCQLNullConstant(); + } + + public static YCQLExpression createBooleanConstant(boolean val) { + return new YCQLBooleanConstant(val); + } + + public static YCQLExpression createDateConstant(long integer) { + return new YCQLDateConstant(integer); + } + + public static YCQLExpression createTimestampConstant(long integer) { + return new YCQLTimestampConstant(integer); + } + +} diff --git a/src/sqlancer/yugabyte/ycql/ast/YCQLExpression.java b/src/sqlancer/yugabyte/ycql/ast/YCQLExpression.java new file mode 100644 index 000000000..4177344eb --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/ast/YCQLExpression.java @@ -0,0 +1,8 @@ +package sqlancer.yugabyte.ycql.ast; + +public interface YCQLExpression { + + default YCQLConstant getExpectedValue() { + return null; + } +} diff --git a/src/sqlancer/yugabyte/ycql/ast/YCQLFunction.java b/src/sqlancer/yugabyte/ycql/ast/YCQLFunction.java new file mode 100644 index 000000000..483fab4c4 --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/ast/YCQLFunction.java @@ -0,0 +1,11 @@ +package sqlancer.yugabyte.ycql.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewFunctionNode; + +public class YCQLFunction extends NewFunctionNode implements YCQLExpression { + public YCQLFunction(List args, F func) { + super(args, func); + } +} diff --git a/src/sqlancer/yugabyte/ycql/ast/YCQLInOperation.java b/src/sqlancer/yugabyte/ycql/ast/YCQLInOperation.java new file mode 100644 index 000000000..49d43be11 --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/ast/YCQLInOperation.java @@ -0,0 +1,11 @@ +package sqlancer.yugabyte.ycql.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewInOperatorNode; + +public class YCQLInOperation extends NewInOperatorNode implements YCQLExpression { + public YCQLInOperation(YCQLExpression left, List right, boolean isNegated) { + super(left, right, isNegated); + } +} diff --git a/src/sqlancer/yugabyte/ycql/ast/YCQLOrderingTerm.java b/src/sqlancer/yugabyte/ycql/ast/YCQLOrderingTerm.java new file mode 100644 index 000000000..5f647e721 --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/ast/YCQLOrderingTerm.java @@ -0,0 +1,9 @@ +package sqlancer.yugabyte.ycql.ast; + +import sqlancer.common.ast.newast.NewOrderingTerm; + +public class YCQLOrderingTerm extends NewOrderingTerm implements YCQLExpression { + public YCQLOrderingTerm(YCQLExpression expr, Ordering ordering) { + super(expr, ordering); + } +} diff --git a/src/sqlancer/yugabyte/ycql/ast/YCQLSelect.java b/src/sqlancer/yugabyte/ycql/ast/YCQLSelect.java new file mode 100644 index 000000000..25257057e --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/ast/YCQLSelect.java @@ -0,0 +1,17 @@ +package sqlancer.yugabyte.ycql.ast; + +import sqlancer.common.ast.SelectBase; + +public class YCQLSelect extends SelectBase implements YCQLExpression { + + private boolean isDistinct; + + public void setDistinct(boolean isDistinct) { + this.isDistinct = isDistinct; + } + + public boolean isDistinct() { + return isDistinct; + } + +} diff --git a/src/sqlancer/yugabyte/ycql/ast/YCQLTableReference.java b/src/sqlancer/yugabyte/ycql/ast/YCQLTableReference.java new file mode 100644 index 000000000..70c01b820 --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/ast/YCQLTableReference.java @@ -0,0 +1,11 @@ +package sqlancer.yugabyte.ycql.ast; + +import sqlancer.common.ast.newast.TableReferenceNode; +import sqlancer.yugabyte.ycql.YCQLSchema; + +public class YCQLTableReference extends TableReferenceNode + implements YCQLExpression { + public YCQLTableReference(YCQLSchema.YCQLTable table) { + super(table); + } +} diff --git a/src/sqlancer/yugabyte/ycql/ast/YCQLUnaryPostfixOperation.java b/src/sqlancer/yugabyte/ycql/ast/YCQLUnaryPostfixOperation.java new file mode 100644 index 000000000..b9165211a --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/ast/YCQLUnaryPostfixOperation.java @@ -0,0 +1,10 @@ +package sqlancer.yugabyte.ycql.ast; + +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.newast.NewUnaryPostfixOperatorNode; + +public class YCQLUnaryPostfixOperation extends NewUnaryPostfixOperatorNode implements YCQLExpression { + public YCQLUnaryPostfixOperation(YCQLExpression expr, BinaryOperatorNode.Operator op) { + super(expr, op); + } +} diff --git a/src/sqlancer/yugabyte/ycql/ast/YCQLUnaryPrefixOperation.java b/src/sqlancer/yugabyte/ycql/ast/YCQLUnaryPrefixOperation.java new file mode 100644 index 000000000..bd3b554ee --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/ast/YCQLUnaryPrefixOperation.java @@ -0,0 +1,10 @@ +package sqlancer.yugabyte.ycql.ast; + +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode; + +public class YCQLUnaryPrefixOperation extends NewUnaryPrefixOperatorNode implements YCQLExpression { + public YCQLUnaryPrefixOperation(YCQLExpression expr, BinaryOperatorNode.Operator operator) { + super(expr, operator); + } +} diff --git a/src/sqlancer/yugabyte/ycql/gen/YCQLAlterTableGenerator.java b/src/sqlancer/yugabyte/ycql/gen/YCQLAlterTableGenerator.java new file mode 100644 index 000000000..d15d7fe1c --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/gen/YCQLAlterTableGenerator.java @@ -0,0 +1,48 @@ +package sqlancer.yugabyte.ycql.gen; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.yugabyte.ycql.YCQLProvider.YCQLGlobalState; +import sqlancer.yugabyte.ycql.YCQLSchema.YCQLCompositeDataType; +import sqlancer.yugabyte.ycql.YCQLSchema.YCQLTable; + +public final class YCQLAlterTableGenerator { + + private YCQLAlterTableGenerator() { + } + + enum Action { + ADD_COLUMN, DROP_COLUMN + } + + public static SQLQueryAdapter getQuery(YCQLGlobalState globalState) { + ExpectedErrors errors = new ExpectedErrors(); + StringBuilder sb = new StringBuilder("ALTER TABLE "); + YCQLTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + sb.append(table.getName()); + sb.append(" "); + Action action = Randomly.fromOptions(Action.values()); + switch (action) { + case ADD_COLUMN: + sb.append("ADD "); + String columnName = table.getFreeColumnName(); + sb.append(columnName); + sb.append(" "); + sb.append(YCQLCompositeDataType.getRandom().toString()); + break; + case DROP_COLUMN: + sb.append("DROP "); + sb.append(table.getRandomColumn().getName()); + break; + default: + throw new AssertionError(action); + } + + errors.add("Alter key column. Can't alter key column"); + errors.add("cannot remove a key column"); + + return new SQLQueryAdapter(sb.toString(), errors, true); + } + +} diff --git a/src/sqlancer/yugabyte/ycql/gen/YCQLDeleteGenerator.java b/src/sqlancer/yugabyte/ycql/gen/YCQLDeleteGenerator.java new file mode 100644 index 000000000..108cd1be9 --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/gen/YCQLDeleteGenerator.java @@ -0,0 +1,31 @@ +package sqlancer.yugabyte.ycql.gen; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.yugabyte.ycql.YCQLErrors; +import sqlancer.yugabyte.ycql.YCQLProvider.YCQLGlobalState; +import sqlancer.yugabyte.ycql.YCQLSchema.YCQLTable; +import sqlancer.yugabyte.ycql.YCQLToStringVisitor; + +public final class YCQLDeleteGenerator { + + private YCQLDeleteGenerator() { + } + + public static SQLQueryAdapter generate(YCQLGlobalState globalState) { + StringBuilder sb = new StringBuilder("DELETE FROM "); + ExpectedErrors errors = new ExpectedErrors(); + YCQLTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + sb.append(table.getName()); + if (Randomly.getBoolean()) { + sb.append(" WHERE "); + sb.append(YCQLToStringVisitor.asString( + new YCQLExpressionGenerator(globalState).setColumns(table.getColumns()).generateExpression())); + } + + YCQLErrors.addExpressionErrors(errors); + return new SQLQueryAdapter(sb.toString(), errors); + } + +} diff --git a/src/sqlancer/yugabyte/ycql/gen/YCQLExpressionGenerator.java b/src/sqlancer/yugabyte/ycql/gen/YCQLExpressionGenerator.java new file mode 100644 index 000000000..0b0e4dee5 --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/gen/YCQLExpressionGenerator.java @@ -0,0 +1,300 @@ +package sqlancer.yugabyte.ycql.gen; + +import static sqlancer.yugabyte.YugabyteBugs.bug14330; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.common.ast.newast.NewOrderingTerm.Ordering; +import sqlancer.common.gen.UntypedExpressionGenerator; +import sqlancer.yugabyte.ycql.YCQLProvider.YCQLGlobalState; +import sqlancer.yugabyte.ycql.YCQLSchema.YCQLColumn; +import sqlancer.yugabyte.ycql.YCQLSchema.YCQLDataType; +import sqlancer.yugabyte.ycql.ast.YCQLBetweenOperation; +import sqlancer.yugabyte.ycql.ast.YCQLBinaryOperation; +import sqlancer.yugabyte.ycql.ast.YCQLColumnReference; +import sqlancer.yugabyte.ycql.ast.YCQLConstant; +import sqlancer.yugabyte.ycql.ast.YCQLExpression; +import sqlancer.yugabyte.ycql.ast.YCQLFunction; +import sqlancer.yugabyte.ycql.ast.YCQLInOperation; +import sqlancer.yugabyte.ycql.ast.YCQLOrderingTerm; +import sqlancer.yugabyte.ycql.ast.YCQLUnaryPostfixOperation; +import sqlancer.yugabyte.ycql.ast.YCQLUnaryPrefixOperation; + +public final class YCQLExpressionGenerator extends UntypedExpressionGenerator { + + private final YCQLGlobalState globalState; + + public YCQLExpressionGenerator(YCQLGlobalState globalState) { + this.globalState = globalState; + } + + private enum Expression { + BINARY_COMPARISON, BINARY_LOGICAL, BINARY_ARITHMETIC, FUNC, BETWEEN, IN + } + + @Override + protected YCQLExpression generateExpression(int depth) { + if (depth >= globalState.getOptions().getMaxExpressionDepth() || Randomly.getBoolean()) { + return generateLeafNode(); + } + if (allowAggregates && Randomly.getBoolean()) { + YCQLAggregateFunction aggregate = YCQLAggregateFunction.getRandom(); + allowAggregates = false; + return new YCQLFunction<>(generateExpressions(depth + 1, aggregate.getNrArgs()), aggregate); + } + List possibleOptions = new ArrayList<>(Arrays.asList(Expression.values())); + Expression expr = Randomly.fromList(possibleOptions); + switch (expr) { + case BINARY_COMPARISON: + Operator op = YCQLBinaryComparisonOperator.getRandom(); + return new YCQLBinaryOperation(generateExpression(depth + 1), generateExpression(depth + 1), op); + case BINARY_LOGICAL: + op = YCQLBinaryLogicalOperator.getRandom(); + return new YCQLBinaryOperation(generateExpression(depth + 1), generateExpression(depth + 1), op); + case BINARY_ARITHMETIC: + return new YCQLBinaryOperation(generateExpression(depth + 1), generateExpression(depth + 1), + YCQLBinaryArithmeticOperator.getRandom()); + case FUNC: + DBFunction func = DBFunction.getRandom(); + return new YCQLFunction(generateExpressions(func.getNrArgs()), func); + case BETWEEN: + return new YCQLBetweenOperation(generateExpression(depth + 1), generateExpression(depth + 1), + generateExpression(depth + 1), Randomly.getBoolean()); + case IN: + return new YCQLInOperation(generateExpression(depth + 1), + generateExpressions(depth + 1, Randomly.smallNumber() + 1), Randomly.getBoolean()); + default: + throw new AssertionError(expr); + } + } + + @Override + protected YCQLExpression generateColumn() { + YCQLColumn column = Randomly.fromList(columns); + return new YCQLColumnReference(column); + } + + @Override + public YCQLExpression generateConstant() { + if (Randomly.getBooleanWithSmallProbability()) { + if (bug14330) { + throw new IgnoreMeException(); + } + + return YCQLConstant.createNullConstant(); + } + YCQLDataType type = YCQLDataType.getRandom(); + switch (type) { + case INT: + return YCQLConstant.createIntConstant(globalState.getRandomly().getInteger()); + case DATE: + return YCQLConstant.createDateConstant(globalState.getRandomly().getInteger()); + case TIMESTAMP: + return YCQLConstant.createTimestampConstant(globalState.getRandomly().getInteger()); + case VARCHAR: + return YCQLConstant.createStringConstant(globalState.getRandomly().getString()); + case BOOLEAN: + return YCQLConstant.createBooleanConstant(Randomly.getBoolean()); + case FLOAT: + return YCQLConstant.createFloatConstant(globalState.getRandomly().getDouble()); + default: + throw new AssertionError(); + } + } + + @Override + public List generateOrderBys() { + List expr = super.generateOrderBys(); + List newExpr = new ArrayList<>(expr.size()); + for (YCQLExpression curExpr : expr) { + if (Randomly.getBoolean()) { + curExpr = new YCQLOrderingTerm(curExpr, Ordering.getRandom()); + } + newExpr.add(curExpr); + } + return newExpr; + }; + + public enum YCQLAggregateFunction { + MAX(1), MIN(1), AVG(1), COUNT(1), SUM(1); + + private final int nrArgs; + + YCQLAggregateFunction(int nrArgs) { + this.nrArgs = nrArgs; + } + + public static YCQLAggregateFunction getRandom() { + return Randomly.fromOptions(values()); + } + + public int getNrArgs() { + return nrArgs; + } + + } + + public enum DBFunction { + // YCQL functions + BLOB(1), // + TIMEUUID(1), // + DATE(0), // + TIME(0), // + TIMESTAMP(0), // + BIGINT(1), // + UUID(0); // + // // extras + // PARTITION_HASH(2), // + // WRITETIME(1), // + // TTL(1); // + + private final int nrArgs; + private final boolean isVariadic; + + DBFunction(int nrArgs) { + this(nrArgs, false); + } + + DBFunction(int nrArgs, boolean isVariadic) { + this.nrArgs = nrArgs; + this.isVariadic = isVariadic; + } + + public static DBFunction getRandom() { + return Randomly.fromOptions(values()); + } + + public int getNrArgs() { + if (isVariadic) { + return Randomly.smallNumber() + nrArgs; + } else { + return nrArgs; + } + } + + } + + public enum YCQLUnaryPostfixOperator implements Operator { + + IS_NULL("IS NULL"), IS_NOT_NULL("IS NOT NULL"); + + private final String textRepr; + + YCQLUnaryPostfixOperator(String textRepr) { + this.textRepr = textRepr; + } + + @Override + public String getTextRepresentation() { + return textRepr; + } + + public static YCQLUnaryPostfixOperator getRandom() { + return Randomly.fromOptions(values()); + } + + } + + public enum YCQLUnaryPrefixOperator implements Operator { + + NOT("NOT"), PLUS("+"), MINUS("-"); + + private final String textRepr; + + YCQLUnaryPrefixOperator(String textRepr) { + this.textRepr = textRepr; + } + + @Override + public String getTextRepresentation() { + return textRepr; + } + + public static YCQLUnaryPrefixOperator getRandom() { + return Randomly.fromOptions(values()); + } + + } + + public enum YCQLBinaryLogicalOperator implements Operator { + + AND, OR; + + @Override + public String getTextRepresentation() { + return toString(); + } + + public static Operator getRandom() { + return Randomly.fromOptions(values()); + } + + } + + public enum YCQLBinaryArithmeticOperator implements Operator { + ADD("+"), SUB("-"), MULT("*"), DIV("/"); + + private String textRepr; + + YCQLBinaryArithmeticOperator(String textRepr) { + this.textRepr = textRepr; + } + + public static Operator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepr; + } + + } + + public enum YCQLBinaryComparisonOperator implements Operator { + + EQUALS("="), GREATER(">"), GREATER_EQUALS(">="), SMALLER("<"), SMALLER_EQUALS("<="), NOT_EQUALS("!="); + + private final String textRepr; + + YCQLBinaryComparisonOperator(String textRepr) { + this.textRepr = textRepr; + } + + public static Operator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepr; + } + + } + + public YCQLFunction generateArgsForAggregate(YCQLAggregateFunction aggregateFunction) { + return new YCQLFunction(generateExpressions(aggregateFunction.getNrArgs()), + aggregateFunction); + } + + public YCQLExpression generateAggregate() { + YCQLAggregateFunction aggrFunc = YCQLAggregateFunction.getRandom(); + return generateArgsForAggregate(aggrFunc); + } + + @Override + public YCQLExpression negatePredicate(YCQLExpression predicate) { + return new YCQLUnaryPrefixOperation(predicate, YCQLUnaryPrefixOperator.NOT); + } + + @Override + public YCQLExpression isNull(YCQLExpression expr) { + return new YCQLUnaryPostfixOperation(expr, YCQLUnaryPostfixOperator.IS_NULL); + } + +} diff --git a/src/sqlancer/yugabyte/ycql/gen/YCQLIndexGenerator.java b/src/sqlancer/yugabyte/ycql/gen/YCQLIndexGenerator.java new file mode 100644 index 000000000..ab03316cc --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/gen/YCQLIndexGenerator.java @@ -0,0 +1,55 @@ +package sqlancer.yugabyte.ycql.gen; + +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.yugabyte.ycql.YCQLProvider.YCQLGlobalState; +import sqlancer.yugabyte.ycql.YCQLSchema.YCQLColumn; +import sqlancer.yugabyte.ycql.YCQLSchema.YCQLTable; +import sqlancer.yugabyte.ycql.YCQLToStringVisitor; +import sqlancer.yugabyte.ycql.ast.YCQLExpression; + +public final class YCQLIndexGenerator { + + private YCQLIndexGenerator() { + } + + public static SQLQueryAdapter getQuery(YCQLGlobalState globalState) { + ExpectedErrors errors = new ExpectedErrors(); + StringBuilder sb = new StringBuilder(); + sb.append("CREATE "); + if (Randomly.getBoolean()) { + errors.add("Cant create unique index, table contains duplicate data on indexed column(s)"); + sb.append("UNIQUE "); + } + sb.append("INDEX "); + sb.append(Randomly.fromOptions("i0", "i1", "i2", "i3", "i4")); + sb.append(" ON "); + YCQLTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + sb.append(table.getName()); + sb.append("("); + List columns = table.getRandomNonEmptyColumnSubset(); + for (int i = 0; i < columns.size(); i++) { + if (i != 0) { + sb.append(", "); + } + sb.append(columns.get(i).getName()); + } + sb.append(")"); + if (Randomly.getBoolean()) { + sb.append(" WHERE "); + YCQLExpression expr = new YCQLExpressionGenerator(globalState).setColumns(table.getColumns()) + .generateExpression(); + sb.append(YCQLToStringVisitor.asString(expr)); + } + errors.add("Query timed out after PT2S"); + errors.add("Invalid SQL Statement"); + errors.add("Invalid CQL Statement"); + errors.add( + "Invalid Table Definition. Transactions cannot be enabled in an index of a table without transactions enabled."); + return new SQLQueryAdapter(sb.toString(), errors, true); + } + +} diff --git a/src/sqlancer/yugabyte/ycql/gen/YCQLInsertGenerator.java b/src/sqlancer/yugabyte/ycql/gen/YCQLInsertGenerator.java new file mode 100644 index 000000000..a1159d310 --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/gen/YCQLInsertGenerator.java @@ -0,0 +1,65 @@ +package sqlancer.yugabyte.ycql.gen; + +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.common.gen.AbstractInsertGenerator; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.yugabyte.ycql.YCQLErrors; +import sqlancer.yugabyte.ycql.YCQLProvider.YCQLGlobalState; +import sqlancer.yugabyte.ycql.YCQLSchema.YCQLColumn; +import sqlancer.yugabyte.ycql.YCQLSchema.YCQLTable; +import sqlancer.yugabyte.ycql.YCQLToStringVisitor; + +public class YCQLInsertGenerator extends AbstractInsertGenerator { + + private final YCQLGlobalState globalState; + private final ExpectedErrors errors = new ExpectedErrors(); + + public YCQLInsertGenerator(YCQLGlobalState globalState) { + this.globalState = globalState; + } + + public static SQLQueryAdapter getQuery(YCQLGlobalState globalState) { + return new YCQLInsertGenerator(globalState).generate(); + } + + private SQLQueryAdapter generate() { + sb.append("INSERT INTO "); + YCQLTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + List columns = table.getColumns(); + sb.append(globalState.getDatabaseName()).append(".").append(table.getName()); + sb.append("("); + sb.append(columns.stream().map(AbstractTableColumn::getName).collect(Collectors.joining(", "))); + sb.append(")"); + sb.append(" VALUES "); + insertColumns(columns); + + errors.add("Invalid Arguments"); + errors.add("Null Argument for Primary Key"); + + YCQLErrors.addExpressionErrors(errors); + return new SQLQueryAdapter(sb.toString(), errors); + } + + @Override + protected void insertColumns(List columns) { + sb.append("("); + for (int nrColumn = 0; nrColumn < columns.size(); nrColumn++) { + if (nrColumn != 0) { + sb.append(", "); + } + insertValue(columns.get(nrColumn)); + } + sb.append(")"); + } + + @Override + protected void insertValue(YCQLColumn columnYCQL) { + // TODO: select a more meaningful value + sb.append(YCQLToStringVisitor.asString(new YCQLExpressionGenerator(globalState).generateConstant())); + } + +} diff --git a/src/sqlancer/yugabyte/ycql/gen/YCQLRandomQuerySynthesizer.java b/src/sqlancer/yugabyte/ycql/gen/YCQLRandomQuerySynthesizer.java new file mode 100644 index 000000000..f7692ebb0 --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/gen/YCQLRandomQuerySynthesizer.java @@ -0,0 +1,53 @@ +package sqlancer.yugabyte.ycql.gen; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.yugabyte.ycql.YCQLProvider.YCQLGlobalState; +import sqlancer.yugabyte.ycql.YCQLSchema.YCQLTable; +import sqlancer.yugabyte.ycql.YCQLSchema.YCQLTables; +import sqlancer.yugabyte.ycql.ast.YCQLConstant; +import sqlancer.yugabyte.ycql.ast.YCQLExpression; +import sqlancer.yugabyte.ycql.ast.YCQLSelect; +import sqlancer.yugabyte.ycql.ast.YCQLTableReference; + +public final class YCQLRandomQuerySynthesizer { + + private YCQLRandomQuerySynthesizer() { + } + + public static YCQLSelect generateSelect(YCQLGlobalState globalState, int nrColumns) { + YCQLTables targetTables = globalState.getSchema().getRandomTableNonEmptyTables(); + YCQLExpressionGenerator gen = new YCQLExpressionGenerator(globalState).setColumns(targetTables.getColumns()); + YCQLSelect select = new YCQLSelect(); + List columns = new ArrayList<>(); + for (int i = 0; i < nrColumns; i++) { + YCQLExpression expression = gen.generateExpression(); + columns.add(expression); + } + select.setFetchColumns(columns); + List tables = targetTables.getTables(); + Optional table = tables.stream().map(t -> new YCQLTableReference(t)).findFirst(); + select.setFromList(table.stream().collect(Collectors.toList())); + if (Randomly.getBoolean()) { + select.setWhereClause(gen.generateExpression()); + } + if (Randomly.getBoolean()) { + select.setOrderByClauses(gen.generateOrderBys()); + } + if (Randomly.getBoolean()) { + select.setGroupByExpressions(Randomly.nonEmptySubset(select.getFetchColumns())); + } + if (Randomly.getBoolean()) { + select.setLimitClause(YCQLConstant.createIntConstant(Randomly.getNotCachedInteger(0, Integer.MAX_VALUE))); + } + if (Randomly.getBoolean()) { + select.setOffsetClause(YCQLConstant.createIntConstant(Randomly.getNotCachedInteger(0, Integer.MAX_VALUE))); + } + return select; + } + +} diff --git a/src/sqlancer/yugabyte/ycql/gen/YCQLTableGenerator.java b/src/sqlancer/yugabyte/ycql/gen/YCQLTableGenerator.java new file mode 100644 index 000000000..148a1cdbc --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/gen/YCQLTableGenerator.java @@ -0,0 +1,57 @@ +package sqlancer.yugabyte.ycql.gen; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.yugabyte.ycql.YCQLProvider.YCQLGlobalState; +import sqlancer.yugabyte.ycql.YCQLSchema.YCQLColumn; +import sqlancer.yugabyte.ycql.YCQLSchema.YCQLCompositeDataType; + +public class YCQLTableGenerator { + + public SQLQueryAdapter getQuery(YCQLGlobalState globalState) { + ExpectedErrors errors = new ExpectedErrors(); + StringBuilder sb = new StringBuilder(); + String tableName = globalState.getSchema().getFreeTableName(); + sb.append("CREATE TABLE "); + if (Randomly.getBoolean()) { + sb.append("IF NOT EXISTS "); + } + sb.append(tableName); + sb.append("("); + List columns = getNewColumns(); + for (int i = 0; i < columns.size(); i++) { + if (i != 0) { + sb.append(", "); + } + sb.append(columns.get(i).getName()); + sb.append(" "); + sb.append(columns.get(i).getType()); + // todo PK, STATIC + } + errors.add("Query timed out after PT2S"); + errors.add("Invalid type for index"); + List primaryKeyColumns = Randomly.nonEmptySubset(columns); + sb.append(", PRIMARY KEY("); + sb.append(primaryKeyColumns.stream().map(AbstractTableColumn::getName).collect(Collectors.joining(", "))); + sb.append(")"); + sb.append(")"); + return new SQLQueryAdapter(sb.toString(), errors, true); + } + + private static List getNewColumns() { + List columns = new ArrayList<>(); + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + String columnName = String.format("c%d", i); + YCQLCompositeDataType columnType = YCQLCompositeDataType.getRandom(); + columns.add(new YCQLColumn(columnName, columnType, false, false)); + } + return columns; + } + +} diff --git a/src/sqlancer/yugabyte/ycql/gen/YCQLUpdateGenerator.java b/src/sqlancer/yugabyte/ycql/gen/YCQLUpdateGenerator.java new file mode 100644 index 000000000..a6c855cf3 --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/gen/YCQLUpdateGenerator.java @@ -0,0 +1,59 @@ +package sqlancer.yugabyte.ycql.gen; + +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.gen.AbstractUpdateGenerator; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.yugabyte.ycql.YCQLErrors; +import sqlancer.yugabyte.ycql.YCQLProvider.YCQLGlobalState; +import sqlancer.yugabyte.ycql.YCQLSchema.YCQLColumn; +import sqlancer.yugabyte.ycql.YCQLSchema.YCQLTable; +import sqlancer.yugabyte.ycql.YCQLToStringVisitor; +import sqlancer.yugabyte.ycql.ast.YCQLExpression; + +public final class YCQLUpdateGenerator extends AbstractUpdateGenerator { + + private final YCQLGlobalState globalState; + private YCQLExpressionGenerator gen; + + private YCQLUpdateGenerator(YCQLGlobalState globalState) { + this.globalState = globalState; + } + + public static SQLQueryAdapter getQuery(YCQLGlobalState globalState) { + return new YCQLUpdateGenerator(globalState).generate(); + } + + private SQLQueryAdapter generate() { + YCQLTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + List columns = table.getRandomNonEmptyColumnSubset(); + gen = new YCQLExpressionGenerator(globalState).setColumns(table.getColumns()); + sb.append("UPDATE "); + sb.append(table.getName()); + sb.append(" SET "); + updateColumns(columns); + errors.add("Invalid Arguments"); + errors.add("Invalid CQL Statement"); + errors.add("Invalid SQL Statement"); + errors.add("Datatype Mismatch"); + errors.add("Null Argument for Primary Key"); + errors.add("Missing Argument for Primary Key"); + + YCQLErrors.addExpressionErrors(errors); + return new SQLQueryAdapter(sb.toString(), errors); + } + + @Override + protected void updateValue(YCQLColumn column) { + YCQLExpression expr; + if (Randomly.getBooleanWithSmallProbability()) { + expr = gen.generateExpression(); + YCQLErrors.addExpressionErrors(errors); + } else { + expr = gen.generateConstant(); + } + sb.append(YCQLToStringVisitor.asString(expr)); + } + +} diff --git a/src/sqlancer/yugabyte/ycql/test/YCQLFuzzer.java b/src/sqlancer/yugabyte/ycql/test/YCQLFuzzer.java new file mode 100644 index 000000000..7c21b7db8 --- /dev/null +++ b/src/sqlancer/yugabyte/ycql/test/YCQLFuzzer.java @@ -0,0 +1,73 @@ +package sqlancer.yugabyte.ycql.test; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.yugabyte.ycql.YCQLProvider; +import sqlancer.yugabyte.ycql.YCQLToStringVisitor; +import sqlancer.yugabyte.ycql.gen.YCQLRandomQuerySynthesizer; + +public class YCQLFuzzer implements TestOracle { + private final YCQLProvider.YCQLGlobalState globalState; + private final List testQueries; + private final ExpectedErrors errors = new ExpectedErrors(); + + public YCQLFuzzer(YCQLProvider.YCQLGlobalState globalState) { + this.globalState = globalState; + + errors.add("Query timed out after PT2S"); + errors.add("Datatype Mismatch"); + errors.add("Invalid CQL Statement"); + errors.add("Invalid SQL Statement"); + errors.add("Invalid Arguments"); + errors.add("Invalid Function Call"); + + testQueries = new ArrayList<>(); + + testQueries.add(new SelectQuery()); + testQueries.add(new ActionQuery(YCQLProvider.Action.UPDATE)); + testQueries.add(new ActionQuery(YCQLProvider.Action.DELETE)); + testQueries.add(new ActionQuery(YCQLProvider.Action.INSERT)); + } + + @Override + public void check() throws Exception { + Query s = testQueries.get(globalState.getRandomly().getInteger(0, testQueries.size())); + globalState.executeStatement(s.getQuery(globalState, errors)); + globalState.getManager().incrementSelectQueryCount(); + } + + private static class Query { + public SQLQueryAdapter getQuery(YCQLProvider.YCQLGlobalState state, ExpectedErrors errors) throws Exception { + throw new IllegalAccessException("Should be implemented"); + }; + } + + private static class ActionQuery extends Query { + private final YCQLProvider.Action action; + + ActionQuery(YCQLProvider.Action action) { + this.action = action; + } + + @Override + public SQLQueryAdapter getQuery(YCQLProvider.YCQLGlobalState state, ExpectedErrors errors) throws Exception { + return action.getQuery(state); + } + } + + private static class SelectQuery extends Query { + + @Override + public SQLQueryAdapter getQuery(YCQLProvider.YCQLGlobalState state, ExpectedErrors errors) throws Exception { + return new SQLQueryAdapter( + YCQLToStringVisitor.asString( + YCQLRandomQuerySynthesizer.generateSelect(state, Randomly.smallNumber() + 1)) + ";", + errors); + } + } +} diff --git a/src/sqlancer/yugabyte/ysql/YSQLCompoundDataType.java b/src/sqlancer/yugabyte/ysql/YSQLCompoundDataType.java new file mode 100644 index 000000000..3a8b8946d --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/YSQLCompoundDataType.java @@ -0,0 +1,45 @@ +package sqlancer.yugabyte.ysql; + +import java.util.Optional; + +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; + +public final class YSQLCompoundDataType { + + private final YSQLDataType dataType; + private final YSQLCompoundDataType elemType; + private final Integer size; + + private YSQLCompoundDataType(YSQLDataType dataType, YSQLCompoundDataType elemType, Integer size) { + this.dataType = dataType; + this.elemType = elemType; + this.size = size; + } + + public static YSQLCompoundDataType create(YSQLDataType type, int size) { + return new YSQLCompoundDataType(type, null, size); + } + + public static YSQLCompoundDataType create(YSQLDataType type) { + return new YSQLCompoundDataType(type, null, null); + } + + public YSQLDataType getDataType() { + return dataType; + } + + public YSQLCompoundDataType getElemType() { + if (elemType == null) { + throw new AssertionError(); + } + return elemType; + } + + public Optional getSize() { + if (size == null) { + return Optional.empty(); + } else { + return Optional.of(size); + } + } +} diff --git a/src/sqlancer/yugabyte/ysql/YSQLErrors.java b/src/sqlancer/yugabyte/ysql/YSQLErrors.java new file mode 100644 index 000000000..de692b352 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/YSQLErrors.java @@ -0,0 +1,234 @@ +package sqlancer.yugabyte.ysql; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.common.query.ExpectedErrors; + +public final class YSQLErrors { + + private YSQLErrors() { + } + + public static List getCommonFetchErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("An I/O error occurred while sending to the backend"); + errors.add("Conflicts with committed transaction"); + errors.add("cannot be changed"); + errors.add("SET TRANSACTION ISOLATION LEVEL must be called before any query"); + + errors.add("FULL JOIN is only supported with merge-joinable or hash-joinable join conditions"); + errors.add("but it cannot be referenced from this part of the query"); + errors.add("missing FROM-clause entry for table"); + + errors.add("canceling statement due to statement timeout"); + + errors.add("non-integer constant in"); + errors.add("must appear in the GROUP BY clause or be used in an aggregate function"); + errors.add("GROUP BY position"); + + return errors; + } + + public static void addCommonFetchErrors(ExpectedErrors errors) { + errors.addAll(getCommonFetchErrors()); + } + + public static List getCommonTableErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("PRIMARY KEY containing column of type 'INET' not yet supported"); + errors.add("PRIMARY KEY containing column of type 'VARBIT' not yet supported"); + errors.add("PRIMARY KEY containing column of type 'INT4RANGE' not yet supported"); + errors.add("INDEX on column of type 'INET' not yet supported"); + errors.add("INDEX on column of type 'VARBIT' not yet supported"); + errors.add("INDEX on column of type 'INT4RANGE' not yet supported"); + errors.add("is not commutative"); // exclude + errors.add("cannot be changed"); + errors.add("operator requires run-time type coercion"); // exclude + + return errors; + } + + public static void addCommonTableErrors(ExpectedErrors errors) { + errors.addAll(getCommonTableErrors()); + } + + public static List getCommonExpressionErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("syntax error at or near \"(\""); + errors.add("does not exist"); + errors.add("is not unique"); + errors.add("cannot be changed"); + errors.add("invalid reference to FROM-clause entry for table"); + + errors.add("Invalid column number"); + errors.add("specified more than once"); + errors.add("You might need to add explicit type casts"); + errors.add("invalid regular expression"); + errors.add("could not determine which collation to use"); + errors.add("invalid input syntax for integer"); + errors.add("invalid regular expression"); + errors.add("operator does not exist"); + errors.add("quantifier operand invalid"); + errors.add("collation mismatch"); + errors.add("collations are not supported"); + errors.add("operator is not unique"); + errors.add("is not a valid binary digit"); + errors.add("invalid hexadecimal digit"); + errors.add("invalid hexadecimal data: odd number of digits"); + errors.add("zero raised to a negative power is undefined"); + errors.add("cannot convert infinity to numeric"); + errors.add("division by zero"); + errors.add("invalid input syntax for type money"); + errors.add("invalid input syntax for type"); + errors.add("cannot cast type"); + errors.add("value overflows numeric format"); + errors.add("is of type boolean but expression is of type text"); + errors.add("a negative number raised to a non-integer power yields a complex result"); + errors.add("could not determine polymorphic type because input has type unknown"); + + errors.addAll(getToCharFunctionErrors()); + errors.addAll(getBitStringOperationErrors()); + errors.addAll(getFunctionErrors()); + errors.addAll(getCommonRangeExpressionErrors()); + errors.addAll(getCommonRegexExpressionErrors()); + + return errors; + } + + public static void addCommonExpressionErrors(ExpectedErrors errors) { + errors.addAll(getCommonExpressionErrors()); + } + + public static List getToCharFunctionErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("multiple decimal points"); + errors.add("and decimal point together"); + errors.add("multiple decimal points"); + errors.add("cannot use \"S\" twice"); + errors.add("must be ahead of \"PR\""); + errors.add("cannot use \"S\" and \"PL\"/\"MI\"/\"SG\"/\"PR\" together"); + errors.add("cannot use \"S\" and \"SG\" together"); + errors.add("cannot use \"S\" and \"MI\" together"); + errors.add("cannot use \"S\" and \"PL\" together"); + errors.add("cannot use \"PR\" and \"S\"/\"PL\"/\"MI\"/\"SG\" together"); + errors.add("is not a number"); + + return errors; + } + + public static void addToCharFunctionErrors(ExpectedErrors errors) { + errors.addAll(getToCharFunctionErrors()); + } + + public static List getBitStringOperationErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("cannot XOR bit strings of different sizes"); + errors.add("cannot AND bit strings of different sizes"); + errors.add("cannot OR bit strings of different sizes"); + errors.add("must be type boolean, not type text"); + + return errors; + } + + public static void addBitStringOperationErrors(ExpectedErrors errors) { + errors.addAll(getBitStringOperationErrors()); + } + + public static List getFunctionErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("out of valid range"); // get_bit/get_byte + errors.add("cannot take logarithm of a negative number"); + errors.add("cannot take logarithm of zero"); + errors.add("requested character too large for encoding"); // chr + errors.add("null character not permitted"); // chr + errors.add("requested character not valid for encoding"); // chr + errors.add("requested length too large"); // repeat + errors.add("invalid memory alloc request size"); // repeat + errors.add("encoding conversion from UTF8 to ASCII not supported"); // to_ascii + errors.add("negative substring length not allowed"); // substr + errors.add("invalid mask length"); // set_masklen + + return errors; + } + + public static void addFunctionErrors(ExpectedErrors errors) { + errors.addAll(getFunctionErrors()); + } + + public static List getCommonRegexExpressionErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("is not a valid hexadecimal digit"); + + return errors; + } + + public static void addCommonRegexExpressionErrors(ExpectedErrors errors) { + errors.addAll(getCommonRangeExpressionErrors()); + } + + public static List getCommonRangeExpressionErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("range lower bound must be less than or equal to range upper bound"); + errors.add("result of range difference would not be contiguous"); + errors.add("out of range"); + errors.add("malformed range literal"); + errors.add("result of range union would not be contiguous"); + + return errors; + } + + public static void addCommonRangeExpressionErrors(ExpectedErrors errors) { + errors.addAll(getCommonRangeExpressionErrors()); + } + + public static void addCommonInsertUpdateErrors(ExpectedErrors errors) { + errors.add("value too long for type character"); + errors.add("not found in view targetlist"); + } + + public static List getGroupingErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("non-integer constant in GROUP BY"); // TODO + errors.add("must appear in the GROUP BY clause or be used in an aggregate function"); + errors.add("is not in select list"); + errors.add("aggregate functions are not allowed in GROUP BY"); + + return errors; + } + + public static void addGroupingErrors(ExpectedErrors errors) { + errors.addAll(getGroupingErrors()); + } + + public static List getViewErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("already exists"); + errors.add("cannot drop columns from view"); + errors.add("non-integer constant in ORDER BY"); // TODO + errors.add("for SELECT DISTINCT, ORDER BY expressions must appear in select list"); // TODO + errors.add("cannot change data type of view column"); + errors.add("specified more than once"); // TODO + errors.add("materialized views must not use temporary tables or views"); + errors.add("does not have the form non-recursive-term UNION [ALL] recursive-term"); + errors.add("is not a view"); + errors.add("non-integer constant in DISTINCT ON"); + errors.add("SELECT DISTINCT ON expressions must match initial ORDER BY expressions"); + + return errors; + } + + public static void addViewErrors(ExpectedErrors errors) { + errors.addAll(getViewErrors()); + } +} diff --git a/src/sqlancer/yugabyte/ysql/YSQLExpectedValueVisitor.java b/src/sqlancer/yugabyte/ysql/YSQLExpectedValueVisitor.java new file mode 100644 index 000000000..cf175ca39 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/YSQLExpectedValueVisitor.java @@ -0,0 +1,152 @@ +package sqlancer.yugabyte.ysql; + +import sqlancer.yugabyte.ysql.ast.YSQLAggregate; +import sqlancer.yugabyte.ysql.ast.YSQLBetweenOperation; +import sqlancer.yugabyte.ysql.ast.YSQLBinaryLogicalOperation; +import sqlancer.yugabyte.ysql.ast.YSQLCastOperation; +import sqlancer.yugabyte.ysql.ast.YSQLColumnValue; +import sqlancer.yugabyte.ysql.ast.YSQLConstant; +import sqlancer.yugabyte.ysql.ast.YSQLExpression; +import sqlancer.yugabyte.ysql.ast.YSQLFunction; +import sqlancer.yugabyte.ysql.ast.YSQLInOperation; +import sqlancer.yugabyte.ysql.ast.YSQLOrderByTerm; +import sqlancer.yugabyte.ysql.ast.YSQLPOSIXRegularExpression; +import sqlancer.yugabyte.ysql.ast.YSQLPostfixOperation; +import sqlancer.yugabyte.ysql.ast.YSQLPostfixText; +import sqlancer.yugabyte.ysql.ast.YSQLPrefixOperation; +import sqlancer.yugabyte.ysql.ast.YSQLSelect; +import sqlancer.yugabyte.ysql.ast.YSQLSelect.YSQLFromTable; +import sqlancer.yugabyte.ysql.ast.YSQLSelect.YSQLSubquery; +import sqlancer.yugabyte.ysql.ast.YSQLSimilarTo; + +public final class YSQLExpectedValueVisitor implements YSQLVisitor { + + private static final int NR_TABS = 0; + private final StringBuilder sb = new StringBuilder(); + + private void print(YSQLExpression expr) { + YSQLToStringVisitor v = new YSQLToStringVisitor(); + v.visit(expr); + sb.append("\t".repeat(NR_TABS)); + sb.append(v.get()); + sb.append(" -- "); + sb.append(expr.getExpectedValue()); + sb.append("\n"); + } + + @Override + public void visit(YSQLConstant constant) { + print(constant); + } + + @Override + public void visit(YSQLPostfixOperation op) { + print(op); + visit(op.getExpression()); + } + + @Override + public void visit(YSQLColumnValue c) { + print(c); + } + + @Override + public void visit(YSQLPrefixOperation op) { + print(op); + visit(op.getExpression()); + } + + @Override + public void visit(YSQLSelect op) { + visit(op.getWhereClause()); + } + + @Override + public void visit(YSQLOrderByTerm op) { + + } + + @Override + public void visit(YSQLFunction f) { + print(f); + for (int i = 0; i < f.getArguments().length; i++) { + visit(f.getArguments()[i]); + } + } + + @Override + public void visit(YSQLCastOperation cast) { + print(cast); + visit(cast.getExpression()); + } + + @Override + public void visit(YSQLBetweenOperation op) { + print(op); + visit(op.getExpr()); + visit(op.getLeft()); + visit(op.getRight()); + } + + @Override + public void visit(YSQLInOperation op) { + print(op); + visit(op.getExpr()); + for (YSQLExpression right : op.getListElements()) { + visit(right); + } + } + + @Override + public void visit(YSQLPostfixText op) { + print(op); + visit(op.getExpr()); + } + + @Override + public void visit(YSQLAggregate op) { + print(op); + for (YSQLExpression expr : op.getArgs()) { + visit(expr); + } + } + + @Override + public void visit(YSQLSimilarTo op) { + print(op); + visit(op.getString()); + visit(op.getSimilarTo()); + if (op.getEscapeCharacter() != null) { + visit(op.getEscapeCharacter()); + } + } + + @Override + public void visit(YSQLPOSIXRegularExpression op) { + print(op); + visit(op.getString()); + visit(op.getRegex()); + } + + @Override + public void visit(YSQLFromTable from) { + print(from); + } + + @Override + public void visit(YSQLSubquery subquery) { + print(subquery); + } + + @Override + public void visit(YSQLBinaryLogicalOperation op) { + print(op); + visit(op.getLeft()); + visit(op.getRight()); + } + + public String get() { + return sb.toString(); + } + +} diff --git a/src/sqlancer/yugabyte/ysql/YSQLGlobalState.java b/src/sqlancer/yugabyte/ysql/YSQLGlobalState.java new file mode 100644 index 000000000..093f06aed --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/YSQLGlobalState.java @@ -0,0 +1,127 @@ +package sqlancer.yugabyte.ysql; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.SQLGlobalState; + +public class YSQLGlobalState extends SQLGlobalState { + + public static final char IMMUTABLE = 'i'; + public static final char STABLE = 's'; + public static final char VOLATILE = 'v'; + // store and allow filtering by function volatility classifications + private final Map functionsAndTypes = new HashMap<>(); + private List operators = Collections.emptyList(); + private List collates = Collections.emptyList(); + private List opClasses = Collections.emptyList(); + private List allowedFunctionTypes = Arrays.asList(IMMUTABLE, STABLE, VOLATILE); + + @Override + public void setConnection(SQLConnection con) { + super.setConnection(con); + try { + this.opClasses = getOpclasses(getConnection()); + this.operators = getOperators(getConnection()); + this.collates = getCollnames(getConnection()); + } catch (SQLException e) { + throw new AssertionError(e); + } + } + + @Override + public YSQLSchema readSchema() throws SQLException { + return YSQLSchema.fromConnection(getConnection(), getDatabaseName()); + } + + private List getCollnames(SQLConnection con) throws SQLException { + List opClasses = new ArrayList<>(); + try (Statement s = con.createStatement()) { + try (ResultSet rs = s + .executeQuery("SELECT collname FROM pg_collation WHERE collname LIKE '%utf8' or collname = 'C';")) { + while (rs.next()) { + opClasses.add(rs.getString(1)); + } + } + } + return opClasses; + } + + private List getOpclasses(SQLConnection con) throws SQLException { + List opClasses = new ArrayList<>(); + try (Statement s = con.createStatement()) { + try (ResultSet rs = s.executeQuery("select opcname FROM pg_opclass;")) { + while (rs.next()) { + opClasses.add(rs.getString(1)); + } + } + } + return opClasses; + } + + private List getOperators(SQLConnection con) throws SQLException { + List opClasses = new ArrayList<>(); + try (Statement s = con.createStatement()) { + try (ResultSet rs = s.executeQuery("SELECT oprname FROM pg_operator;")) { + while (rs.next()) { + opClasses.add(rs.getString(1)); + } + } + } + return opClasses; + } + + public List getOperators() { + return operators; + } + + public String getRandomOperator() { + return Randomly.fromList(operators); + } + + public List getCollates() { + return collates; + } + + public String getRandomCollate() { + return Randomly.fromList(collates); + } + + public List getOpClasses() { + return opClasses; + } + + public String getRandomOpclass() { + return Randomly.fromList(opClasses); + } + + public void addFunctionAndType(String functionName, Character functionType) { + this.functionsAndTypes.put(functionName, functionType); + } + + public Map getFunctionsAndTypes() { + return this.functionsAndTypes; + } + + public void setDefaultAllowedFunctionTypes() { + this.allowedFunctionTypes = Arrays.asList(IMMUTABLE, STABLE, VOLATILE); + } + + public List getAllowedFunctionTypes() { + return this.allowedFunctionTypes; + } + + public void setAllowedFunctionTypes(List types) { + this.allowedFunctionTypes = types; + } + +} diff --git a/src/sqlancer/yugabyte/ysql/YSQLOptions.java b/src/sqlancer/yugabyte/ysql/YSQLOptions.java new file mode 100644 index 000000000..71b886ba5 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/YSQLOptions.java @@ -0,0 +1,35 @@ +package sqlancer.yugabyte.ysql; + +import java.util.Arrays; +import java.util.List; + +import com.beust.jcommander.Parameter; +import com.beust.jcommander.Parameters; + +import sqlancer.DBMSSpecificOptions; + +@Parameters(separators = "=", commandDescription = "YSQL (default port: " + YSQLOptions.DEFAULT_PORT + + ", default host: " + YSQLOptions.DEFAULT_HOST) +public class YSQLOptions implements DBMSSpecificOptions { + public static final String DEFAULT_HOST = "localhost"; + public static final int DEFAULT_PORT = 5433; + + @Parameter(names = "--bulk-insert", description = "Specifies whether INSERT statements should be issued in bulk", arity = 1) + public boolean allowBulkInsert; + + @Parameter(names = "--oracle", description = "Specifies which test oracle should be used for YSQL") + public List oracle = Arrays.asList(YSQLOracleFactory.QUERY_PARTITIONING); + + @Parameter(names = "--test-collations", description = "Specifies whether to test different collations", arity = 1) + public boolean testCollations = true; + + @Parameter(names = "--connection-url", description = "Specifies the URL for connecting to the YSQL server", arity = 1) + public String connectionURL = String.format("jdbc:yugabytedb://%s:%d/yugabyte", YSQLOptions.DEFAULT_HOST, + YSQLOptions.DEFAULT_PORT); + + @Override + public List getTestOracleFactory() { + return oracle; + } + +} diff --git a/src/sqlancer/yugabyte/ysql/YSQLOracleFactory.java b/src/sqlancer/yugabyte/ysql/YSQLOracleFactory.java new file mode 100644 index 000000000..b85f80593 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/YSQLOracleFactory.java @@ -0,0 +1,72 @@ +package sqlancer.yugabyte.ysql; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import sqlancer.OracleFactory; +import sqlancer.common.oracle.CompositeTestOracle; +import sqlancer.common.oracle.NoRECOracle; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.yugabyte.ysql.gen.YSQLExpressionGenerator; +import sqlancer.yugabyte.ysql.oracle.YSQLCatalog; +import sqlancer.yugabyte.ysql.oracle.YSQLFuzzer; +import sqlancer.yugabyte.ysql.oracle.YSQLPivotedQuerySynthesisOracle; +import sqlancer.yugabyte.ysql.oracle.tlp.YSQLTLPAggregateOracle; +import sqlancer.yugabyte.ysql.oracle.tlp.YSQLTLPHavingOracle; +import sqlancer.yugabyte.ysql.oracle.tlp.YSQLTLPWhereOracle; + +public enum YSQLOracleFactory implements OracleFactory { + FUZZER { + @Override + public TestOracle create(YSQLGlobalState globalState) throws SQLException { + return new YSQLFuzzer(globalState); + } + }, + CATALOG { + @Override + public TestOracle create(YSQLGlobalState globalState) throws SQLException { + return new YSQLCatalog(globalState); + } + }, + NOREC { + @Override + public TestOracle create(YSQLGlobalState globalState) throws SQLException { + YSQLExpressionGenerator gen = new YSQLExpressionGenerator(globalState); + ExpectedErrors errors = ExpectedErrors.newErrors().with(YSQLErrors.getCommonExpressionErrors()) + .with(YSQLErrors.getCommonFetchErrors()).with("canceling statement due to statement timeout") + .build(); + return new NoRECOracle<>(globalState, gen, errors); + } + }, + PQS { + @Override + public TestOracle create(YSQLGlobalState globalState) throws SQLException { + return new YSQLPivotedQuerySynthesisOracle(globalState); + } + + @Override + public boolean requiresAllTablesToContainRows() { + return true; + } + }, + HAVING { + @Override + public TestOracle create(YSQLGlobalState globalState) throws SQLException { + return new YSQLTLPHavingOracle(globalState); + } + + }, + QUERY_PARTITIONING { + @Override + public TestOracle create(YSQLGlobalState globalState) throws SQLException { + List> oracles = new ArrayList<>(); + oracles.add(new YSQLTLPWhereOracle(globalState)); + oracles.add(new YSQLTLPHavingOracle(globalState)); + oracles.add(new YSQLTLPAggregateOracle(globalState)); + return new CompositeTestOracle(oracles, globalState); + } + } + +} diff --git a/src/sqlancer/yugabyte/ysql/YSQLProvider.java b/src/sqlancer/yugabyte/ysql/YSQLProvider.java new file mode 100644 index 000000000..efcc7ec22 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/YSQLProvider.java @@ -0,0 +1,373 @@ +package sqlancer.yugabyte.ysql; + +import java.net.URI; +import java.net.URISyntaxException; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Arrays; + +import com.google.auto.service.AutoService; + +import sqlancer.AbstractAction; +import sqlancer.DatabaseProvider; +import sqlancer.IgnoreMeException; +import sqlancer.MainOptions; +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.SQLProviderAdapter; +import sqlancer.StatementExecutor; +import sqlancer.common.DBMSCommon; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.query.SQLQueryProvider; +import sqlancer.common.query.SQLancerResultSet; +import sqlancer.yugabyte.ysql.gen.YSQLAlterTableGenerator; +import sqlancer.yugabyte.ysql.gen.YSQLAnalyzeGenerator; +import sqlancer.yugabyte.ysql.gen.YSQLCommentGenerator; +import sqlancer.yugabyte.ysql.gen.YSQLDeleteGenerator; +import sqlancer.yugabyte.ysql.gen.YSQLDiscardGenerator; +import sqlancer.yugabyte.ysql.gen.YSQLDropIndexGenerator; +import sqlancer.yugabyte.ysql.gen.YSQLIndexGenerator; +import sqlancer.yugabyte.ysql.gen.YSQLInsertGenerator; +import sqlancer.yugabyte.ysql.gen.YSQLNotifyGenerator; +import sqlancer.yugabyte.ysql.gen.YSQLSequenceGenerator; +import sqlancer.yugabyte.ysql.gen.YSQLSetGenerator; +import sqlancer.yugabyte.ysql.gen.YSQLTableGenerator; +import sqlancer.yugabyte.ysql.gen.YSQLTableGroupGenerator; +import sqlancer.yugabyte.ysql.gen.YSQLTransactionGenerator; +import sqlancer.yugabyte.ysql.gen.YSQLTruncateGenerator; +import sqlancer.yugabyte.ysql.gen.YSQLUpdateGenerator; +import sqlancer.yugabyte.ysql.gen.YSQLVacuumGenerator; +import sqlancer.yugabyte.ysql.gen.YSQLViewGenerator; + +@AutoService(DatabaseProvider.class) +public class YSQLProvider extends SQLProviderAdapter { + + // TODO Due to yugabyte problems with parallel DDL we need this lock object + public static final Object DDL_LOCK = new Object(); + /** + * Generate only data types and expressions that are understood by PQS. + */ + public static boolean generateOnlyKnown; + protected String entryURL; + protected String username; + protected String password; + protected String entryPath; + protected String host; + protected int port; + protected String testURL; + protected String databaseName; + protected String createDatabaseCommand; + + public YSQLProvider() { + super(YSQLGlobalState.class, YSQLOptions.class); + } + + protected YSQLProvider(Class globalClass, Class optionClass) { + super(globalClass, optionClass); + } + + public static int mapActions(YSQLGlobalState globalState, Action a) { + Randomly r = globalState.getRandomly(); + int nrPerformed; + switch (a) { + case CREATE_INDEX: + nrPerformed = r.getInteger(0, 3); + break; + case DISCARD: + case DROP_INDEX: + nrPerformed = r.getInteger(0, 5); + break; + case COMMIT: + nrPerformed = r.getInteger(0, 0); + break; + case ALTER_TABLE: + nrPerformed = r.getInteger(0, 5); + break; + case RESET: + nrPerformed = r.getInteger(0, 3); + break; + case ANALYZE: + nrPerformed = r.getInteger(0, 3); + break; + case TABLEGROUP: + nrPerformed = r.getInteger(0, 3); + break; + case DELETE: + case RESET_ROLE: + case VACUUM: + case SET_CONSTRAINTS: + case SET: + case COMMENT_ON: + case NOTIFY: + case LISTEN: + case UNLISTEN: + case CREATE_SEQUENCE: + case TRUNCATE: + nrPerformed = r.getInteger(0, 2); + break; + case CREATE_VIEW: + nrPerformed = r.getInteger(0, 2); + break; + case UPDATE: + nrPerformed = r.getInteger(0, 10); + break; + case INSERT: + nrPerformed = r.getInteger(0, globalState.getOptions().getMaxNumberInserts()); + break; + default: + throw new AssertionError(a); + } + return nrPerformed; + + } + + @Override + public void generateDatabase(YSQLGlobalState globalState) throws Exception { + readFunctions(globalState); + createTables(globalState, Randomly.fromOptions(4, 5, 6)); + prepareTables(globalState); + } + + @Override + public SQLConnection createDatabase(YSQLGlobalState globalState) throws SQLException { + username = globalState.getOptions().getUserName(); + password = globalState.getOptions().getPassword(); + host = globalState.getOptions().getHost(); + port = globalState.getOptions().getPort(); + entryPath = "/yugabyte"; + entryURL = globalState.getDbmsSpecificOptions().connectionURL; + String entryDatabaseName = entryPath.substring(1); + databaseName = globalState.getDatabaseName(); + + if (host == null) { + host = YSQLOptions.DEFAULT_HOST; + } + if (port == MainOptions.NO_SET_PORT) { + port = YSQLOptions.DEFAULT_PORT; + } + + try { + URI uri = new URI(entryURL); + String userInfoURI = uri.getUserInfo(); + String pathURI = uri.getPath(); + if (userInfoURI != null) { + // username and password specified in URL take precedence + if (userInfoURI.contains(":")) { + String[] userInfo = userInfoURI.split(":", 2); + username = userInfo[0]; + password = userInfo[1]; + } else { + username = userInfoURI; + password = null; + } + int userInfoIndex = entryURL.indexOf(userInfoURI); + String preUserInfo = entryURL.substring(0, userInfoIndex); + String postUserInfo = entryURL.substring(userInfoIndex + userInfoURI.length() + 1); + entryURL = preUserInfo + postUserInfo; + } + if (pathURI != null) { + entryPath = pathURI; + } + if (host == null) { + host = uri.getHost(); + } + if (port == MainOptions.NO_SET_PORT) { + port = uri.getPort(); + } + entryURL = String.format("jdbc:yugabytedb://%s:%d/%s", host, port, entryDatabaseName); + } catch (URISyntaxException e) { + throw new AssertionError(e); + } + + createDatabaseSync(globalState, entryDatabaseName); + + int databaseIndex = entryURL.indexOf("/" + entryDatabaseName) + 1; + String preDatabaseName = entryURL.substring(0, databaseIndex); + String postDatabaseName = entryURL.substring(databaseIndex + entryDatabaseName.length()); + testURL = preDatabaseName + databaseName + postDatabaseName; + globalState.getState().logStatement(String.format("\\c %s;", databaseName)); + + return new SQLConnection(createConnectionSafely(testURL, username, password)); + } + + @Override + public String getDBMSName() { + return "ysql"; + } + + // for some reason yugabyte unable to create few databases simultaneously + private void createDatabaseSync(YSQLGlobalState globalState, String entryDatabaseName) throws SQLException { + synchronized (DDL_LOCK) { + exceptionLessSleep(5000); + + Connection con = createConnectionSafely(entryURL, username, password); + globalState.getState().logStatement(String.format("\\c %s;", entryDatabaseName)); + globalState.getState().logStatement("DROP DATABASE IF EXISTS " + databaseName + " WITH (FORCE)"); + createDatabaseCommand = getCreateDatabaseCommand(globalState); + globalState.getState().logStatement(createDatabaseCommand); + try (Statement s = con.createStatement()) { + s.execute("DROP DATABASE IF EXISTS " + databaseName + " WITH (FORCE)"); + } + try (Statement s = con.createStatement()) { + s.execute(createDatabaseCommand); + } + con.close(); + } + } + + private Connection createConnectionSafely(String entryURL, String user, String password) { + Connection con = null; + IllegalStateException lastException = new IllegalStateException("Empty exception"); + long endTime = System.currentTimeMillis() + 30000; + while (System.currentTimeMillis() < endTime) { + try { + con = DriverManager.getConnection(entryURL, user, password); + break; + } catch (SQLException throwables) { + lastException = new IllegalStateException(throwables); + } + } + + if (con == null) { + throw lastException; + } + + return con; + } + + protected void readFunctions(YSQLGlobalState globalState) throws SQLException { + SQLQueryAdapter query = new SQLQueryAdapter("SELECT proname, provolatile FROM pg_proc;"); + SQLancerResultSet rs = query.executeAndGet(globalState); + while (rs.next()) { + String functionName = rs.getString(1); + Character functionType = rs.getString(2).charAt(0); + globalState.addFunctionAndType(functionName, functionType); + } + } + + protected void createTables(YSQLGlobalState globalState, int numTables) throws Exception { + synchronized (DDL_LOCK) { + boolean prevCreationFailed = false; // small optimization - wait only after failed requests + while (globalState.getSchema().getDatabaseTables().size() < numTables) { + if (!prevCreationFailed) { + exceptionLessSleep(5000); + } + + try { + String tableName = DBMSCommon.createTableName(globalState.getSchema().getDatabaseTables().size()); + SQLQueryAdapter createTable = YSQLTableGenerator.generate(tableName, generateOnlyKnown, + globalState); + globalState.executeStatement(createTable); + prevCreationFailed = false; + } catch (IgnoreMeException e) { + prevCreationFailed = true; + } + } + } + } + + private void exceptionLessSleep(long timeout) { + try { + Thread.sleep(timeout); + } catch (InterruptedException e) { + throw new AssertionError(); + } + } + + protected void prepareTables(YSQLGlobalState globalState) throws Exception { + StatementExecutor se = new StatementExecutor<>(globalState, Action.values(), + YSQLProvider::mapActions, (q) -> { + if (globalState.getSchema().getDatabaseTables().isEmpty()) { + throw new IgnoreMeException(); + } + }); + se.executeStatements(); + globalState.executeStatement(new SQLQueryAdapter("COMMIT", true)); + globalState.executeStatement(new SQLQueryAdapter("SET SESSION statement_timeout = 15000;\n")); + } + + private String getCreateDatabaseCommand(YSQLGlobalState state) { + StringBuilder sb = new StringBuilder(); + sb.append("CREATE DATABASE ").append(databaseName).append(" "); + if (Randomly.getBoolean() && state.getDbmsSpecificOptions().testCollations) { + sb.append("WITH "); + if (Randomly.getBoolean()) { + sb.append("ENCODING '"); + sb.append(Randomly.fromOptions("utf8")); + sb.append("' "); + } + + if (Randomly.getBoolean()) { + // if (YugabyteBugs.bug11357) { + // throw new IgnoreMeException(); + // } + + sb.append("COLOCATED = true "); + } + + for (String lc : Arrays.asList("LC_COLLATE", "LC_CTYPE")) { + if (!state.getCollates().isEmpty() && Randomly.getBoolean()) { + sb.append(String.format(" %s = '%s'", lc, Randomly.fromList(state.getCollates()))); + } + } + sb.append(" TEMPLATE template0"); + + } + return sb.toString(); + } + + public enum Action implements AbstractAction { + ANALYZE(YSQLAnalyzeGenerator::create), // + ALTER_TABLE(g -> YSQLAlterTableGenerator.create(g.getSchema().getRandomTable(t -> !t.isView()), g)), // + COMMIT(g -> { + SQLQueryAdapter query; + if (Randomly.getBoolean()) { + query = new SQLQueryAdapter("COMMIT", true); + } else if (Randomly.getBoolean()) { + query = YSQLTransactionGenerator.executeBegin(); + } else { + query = new SQLQueryAdapter("ROLLBACK", true); + } + return query; + }), // + DELETE(YSQLDeleteGenerator::create), // + DISCARD(YSQLDiscardGenerator::create), // + DROP_INDEX(YSQLDropIndexGenerator::create), // + CREATE_INDEX(YSQLIndexGenerator::generate), // + INSERT(YSQLInsertGenerator::insert), // + UPDATE(YSQLUpdateGenerator::create), // + TRUNCATE(YSQLTruncateGenerator::create), // + TABLEGROUP(YSQLTableGroupGenerator::create), // + VACUUM(YSQLVacuumGenerator::create), // + SET(YSQLSetGenerator::create), // TODO insert yugabyte sets + SET_CONSTRAINTS((g) -> { + String sb = "SET CONSTRAINTS ALL " + Randomly.fromOptions("DEFERRED", "IMMEDIATE"); + return new SQLQueryAdapter(sb); + }), // + RESET_ROLE((g) -> new SQLQueryAdapter("RESET ROLE")), // + COMMENT_ON(YSQLCommentGenerator::generate), // + RESET((g) -> new SQLQueryAdapter("RESET ALL") /* + * https://www.postgres.org/docs/devel/sql-reset.html TODO: also + * configuration parameter + */), // + NOTIFY(YSQLNotifyGenerator::createNotify), // + LISTEN((g) -> YSQLNotifyGenerator.createListen()), // + UNLISTEN((g) -> YSQLNotifyGenerator.createUnlisten()), // + CREATE_SEQUENCE(YSQLSequenceGenerator::createSequence), // + CREATE_VIEW(YSQLViewGenerator::create); + + private final SQLQueryProvider sqlQueryProvider; + + Action(SQLQueryProvider sqlQueryProvider) { + this.sqlQueryProvider = sqlQueryProvider; + } + + @Override + public SQLQueryAdapter getQuery(YSQLGlobalState state) throws Exception { + return sqlQueryProvider.getQuery(state); + } + } + +} diff --git a/src/sqlancer/yugabyte/ysql/YSQLSchema.java b/src/sqlancer/yugabyte/ysql/YSQLSchema.java new file mode 100644 index 000000000..c75322af9 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/YSQLSchema.java @@ -0,0 +1,341 @@ +package sqlancer.yugabyte.ysql; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.SQLIntegrityConstraintViolationException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.postgresql.util.PSQLException; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.common.DBMSCommon; +import sqlancer.common.schema.AbstractRelationalTable; +import sqlancer.common.schema.AbstractRowValue; +import sqlancer.common.schema.AbstractSchema; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.common.schema.AbstractTables; +import sqlancer.common.schema.TableIndex; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLTable; +import sqlancer.yugabyte.ysql.ast.YSQLConstant; + +public class YSQLSchema extends AbstractSchema { + + private final String databaseName; + + public YSQLSchema(List databaseTables, String databaseName) { + super(databaseTables); + this.databaseName = databaseName; + } + + public static YSQLDataType getColumnType(String typeString) { + switch (typeString) { + case "smallint": + case "integer": + case "bigint": + return YSQLDataType.INT; + case "boolean": + return YSQLDataType.BOOLEAN; + case "text": + case "character": + case "character varying": + case "name": + return YSQLDataType.TEXT; + case "numeric": + return YSQLDataType.DECIMAL; + case "double precision": + return YSQLDataType.FLOAT; + case "real": + return YSQLDataType.REAL; + case "int4range": + return YSQLDataType.RANGE; + case "money": + return YSQLDataType.MONEY; + case "bytea": + return YSQLDataType.BYTEA; + case "bit": + case "bit varying": + return YSQLDataType.BIT; + case "inet": + return YSQLDataType.INET; + default: + throw new AssertionError(typeString); + } + } + + public static YSQLSchema fromConnection(SQLConnection con, String databaseName) throws SQLException { + try { + List databaseTables = new ArrayList<>(); + try (Statement s = con.createStatement()) { + try (ResultSet rs = s.executeQuery( + "SELECT table_name, table_schema, table_type, is_insertable_into FROM information_schema.tables WHERE table_schema='public' OR table_schema LIKE 'pg_temp_%' ORDER BY table_name;")) { + while (rs.next()) { + String tableName = rs.getString("table_name"); + String tableTypeSchema = rs.getString("table_schema"); + boolean isInsertable = rs.getBoolean("is_insertable_into"); + // TODO: also check insertable + // TODO: insert into view? + boolean isView = matchesViewName(tableName); // tableTypeStr.contains("VIEW") || + // tableTypeStr.contains("LOCAL TEMPORARY") && + // !isInsertable; + YSQLTable.TableType tableType = getTableType(tableTypeSchema); + List databaseColumns = getTableColumns(con, tableName); + List indexes = getIndexes(con, tableName); + List statistics = getStatistics(con); + YSQLTable t = new YSQLTable(tableName, databaseColumns, indexes, tableType, statistics, isView, + isInsertable); + for (YSQLColumn c : databaseColumns) { + c.setTable(t); + } + databaseTables.add(t); + } + } + } + return new YSQLSchema(databaseTables, databaseName); + } catch (SQLIntegrityConstraintViolationException e) { + throw new AssertionError(e); + } + } + + protected static List getStatistics(SQLConnection con) throws SQLException { + List statistics = new ArrayList<>(); + try (Statement s = con.createStatement()) { + try (ResultSet rs = s.executeQuery("SELECT stxname FROM pg_statistic_ext ORDER BY stxname;")) { + while (rs.next()) { + statistics.add(new YSQLStatisticsObject(rs.getString("stxname"))); + } + } + } + return statistics; + } + + protected static YSQLTable.TableType getTableType(String tableTypeStr) throws AssertionError { + YSQLTable.TableType tableType; + if (tableTypeStr.contentEquals("public")) { + tableType = YSQLTable.TableType.STANDARD; + } else if (tableTypeStr.startsWith("pg_temp")) { + tableType = YSQLTable.TableType.TEMPORARY; + } else { + throw new AssertionError(tableTypeStr); + } + return tableType; + } + + protected static List getIndexes(SQLConnection con, String tableName) throws SQLException { + List indexes = new ArrayList<>(); + try (Statement s = con.createStatement()) { + try (ResultSet rs = s.executeQuery(String + .format("SELECT indexname FROM pg_indexes WHERE tablename='%s' ORDER BY indexname;", tableName))) { + while (rs.next()) { + String indexName = rs.getString("indexname"); + if (DBMSCommon.matchesIndexName(indexName)) { + indexes.add(YSQLIndex.create(indexName)); + } + } + } + } + return indexes; + } + + protected static List getTableColumns(SQLConnection con, String tableName) throws SQLException { + List columns = new ArrayList<>(); + try (Statement s = con.createStatement()) { + try (ResultSet rs = s + .executeQuery("select column_name, data_type from INFORMATION_SCHEMA.COLUMNS where table_name = '" + + tableName + "' ORDER BY column_name")) { + while (rs.next()) { + String columnName = rs.getString("column_name"); + String dataType = rs.getString("data_type"); + YSQLColumn c = new YSQLColumn(columnName, getColumnType(dataType)); + columns.add(c); + } + } + } + return columns; + } + + public boolean getDatabaseIsColocated(SQLConnection con) { + try (Statement s = con.createStatement(); ResultSet rs = s.executeQuery("SELECT yb_is_database_colocated();")) { + rs.next(); + String result = rs.getString(1); + // The query will result in a 'f' for a non-colocated database + return !"f".equals(result); + + } catch (SQLException e) { + throw new AssertionError(e); + } + } + + public YSQLTables getRandomTableNonEmptyTables() { + return new YSQLTables(Randomly.nonEmptySubset(getDatabaseTables())); + } + + public String getDatabaseName() { + return databaseName; + } + + public enum YSQLDataType { + // TODO: 23.02.2022 Planned types + // SMALLINT, INT, BIGINT, NUMERIC, DECIMAL, REAL, DOUBLE_PRECISION, VARCHAR, CHAR, TEXT, DATE, TIME, + // TIMESTAMP, TIMESTAMPZ, INTERVAL, INTEGER_ARR + INT, BOOLEAN, BYTEA, TEXT, DECIMAL, FLOAT, REAL, RANGE, MONEY, BIT, INET; + + public static YSQLDataType getRandomType() { + List dataTypes = new ArrayList<>(Arrays.asList(values())); + if (YSQLProvider.generateOnlyKnown) { + dataTypes.remove(YSQLDataType.DECIMAL); + dataTypes.remove(YSQLDataType.FLOAT); + dataTypes.remove(YSQLDataType.REAL); + dataTypes.remove(YSQLDataType.INET); + dataTypes.remove(YSQLDataType.RANGE); + dataTypes.remove(YSQLDataType.MONEY); + dataTypes.remove(YSQLDataType.BIT); + } + return Randomly.fromList(dataTypes); + } + } + + public static class YSQLColumn extends AbstractTableColumn { + + public YSQLColumn(String name, YSQLDataType columnType) { + super(name, null, columnType); + } + + public static YSQLColumn createDummy(String name) { + return new YSQLColumn(name, YSQLDataType.INT); + } + + } + + public static class YSQLTables extends AbstractTables { + + public YSQLTables(List tables) { + super(tables); + } + + public YSQLRowValue getRandomRowValue(SQLConnection con) throws SQLException { + String randomRow = String.format("SELECT %s FROM %s ORDER BY RANDOM() LIMIT 1", columnNamesAsString( + c -> c.getTable().getName() + "." + c.getName() + " AS " + c.getTable().getName() + c.getName()), + // columnNamesAsString(c -> "typeof(" + c.getTable().getName() + "." + + // c.getName() + ")") + tableNamesAsString()); + Map values = new HashMap<>(); + try (Statement s = con.createStatement()) { + ResultSet randomRowValues = s.executeQuery(randomRow); + if (!randomRowValues.next()) { + throw new AssertionError("could not find random row! " + randomRow + "\n"); + } + for (int i = 0; i < getColumns().size(); i++) { + YSQLColumn column = getColumns().get(i); + int columnIndex = randomRowValues.findColumn(column.getTable().getName() + column.getName()); + assert columnIndex == i + 1; + YSQLConstant constant; + if (randomRowValues.getString(columnIndex) == null) { + constant = YSQLConstant.createNullConstant(); + } else { + switch (column.getType()) { + case INT: + constant = YSQLConstant.createIntConstant(randomRowValues.getLong(columnIndex)); + break; + case BOOLEAN: + constant = YSQLConstant.createBooleanConstant(randomRowValues.getBoolean(columnIndex)); + break; + case TEXT: + constant = YSQLConstant.createTextConstant(randomRowValues.getString(columnIndex)); + break; + default: + throw new IgnoreMeException(); + } + } + values.put(column, constant); + } + assert !randomRowValues.next(); + return new YSQLRowValue(this, values); + } catch (PSQLException e) { + throw new IgnoreMeException(); + } + + } + + } + + public static class YSQLRowValue extends AbstractRowValue { + + protected YSQLRowValue(YSQLTables tables, Map values) { + super(tables, values); + } + + } + + public static class YSQLTable extends AbstractRelationalTable { + + private final TableType tableType; + private final List statistics; + private final boolean isInsertable; + + public YSQLTable(String tableName, List columns, List indexes, TableType tableType, + List statistics, boolean isView, boolean isInsertable) { + super(tableName, columns, indexes, isView); + this.statistics = statistics; + this.isInsertable = isInsertable; + this.tableType = tableType; + } + + public List getStatistics() { + return statistics; + } + + public TableType getTableType() { + return tableType; + } + + public boolean isInsertable() { + return isInsertable; + } + + public enum TableType { + STANDARD, TEMPORARY + } + + } + + public static final class YSQLStatisticsObject { + private final String name; + + public YSQLStatisticsObject(String name) { + this.name = name; + } + + public String getName() { + return name; + } + } + + public static final class YSQLIndex extends TableIndex { + + private YSQLIndex(String indexName) { + super(indexName); + } + + public static YSQLIndex create(String indexName) { + return new YSQLIndex(indexName); + } + + @Override + public String getIndexName() { + if (super.getIndexName().contentEquals("PRIMARY")) { + return "`PRIMARY`"; + } else { + return super.getIndexName(); + } + } + + } + +} diff --git a/src/sqlancer/yugabyte/ysql/YSQLToStringVisitor.java b/src/sqlancer/yugabyte/ysql/YSQLToStringVisitor.java new file mode 100644 index 000000000..54e718768 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/YSQLToStringVisitor.java @@ -0,0 +1,329 @@ +package sqlancer.yugabyte.ysql; + +import java.util.Optional; + +import sqlancer.Randomly; +import sqlancer.common.visitor.BinaryOperation; +import sqlancer.common.visitor.ToStringVisitor; +import sqlancer.yugabyte.ysql.ast.YSQLAggregate; +import sqlancer.yugabyte.ysql.ast.YSQLBetweenOperation; +import sqlancer.yugabyte.ysql.ast.YSQLBinaryLogicalOperation; +import sqlancer.yugabyte.ysql.ast.YSQLCastOperation; +import sqlancer.yugabyte.ysql.ast.YSQLColumnValue; +import sqlancer.yugabyte.ysql.ast.YSQLConstant; +import sqlancer.yugabyte.ysql.ast.YSQLExpression; +import sqlancer.yugabyte.ysql.ast.YSQLFunction; +import sqlancer.yugabyte.ysql.ast.YSQLInOperation; +import sqlancer.yugabyte.ysql.ast.YSQLJoin; +import sqlancer.yugabyte.ysql.ast.YSQLJoin.YSQLJoinType; +import sqlancer.yugabyte.ysql.ast.YSQLOrderByTerm; +import sqlancer.yugabyte.ysql.ast.YSQLPOSIXRegularExpression; +import sqlancer.yugabyte.ysql.ast.YSQLPostfixOperation; +import sqlancer.yugabyte.ysql.ast.YSQLPostfixText; +import sqlancer.yugabyte.ysql.ast.YSQLPrefixOperation; +import sqlancer.yugabyte.ysql.ast.YSQLSelect; +import sqlancer.yugabyte.ysql.ast.YSQLSelect.YSQLFromTable; +import sqlancer.yugabyte.ysql.ast.YSQLSelect.YSQLSubquery; +import sqlancer.yugabyte.ysql.ast.YSQLSimilarTo; + +public final class YSQLToStringVisitor extends ToStringVisitor implements YSQLVisitor { + + @Override + public void visitSpecific(YSQLExpression expr) { + YSQLVisitor.super.visit(expr); + } + + @Override + public String get() { + return sb.toString(); + } + + @Override + public void visit(YSQLConstant constant) { + sb.append(constant.getTextRepresentation()); + } + + @Override + public void visit(YSQLPostfixOperation op) { + sb.append("("); + visit(op.getExpression()); + sb.append(")"); + sb.append(" "); + sb.append(op.getOperatorTextRepresentation()); + } + + @Override + public void visit(YSQLColumnValue c) { + sb.append(c.getColumn().getFullQualifiedName()); + } + + @Override + public void visit(YSQLPrefixOperation op) { + sb.append(op.getTextRepresentation()); + sb.append(" ("); + visit(op.getExpression()); + sb.append(")"); + } + + @Override + public void visit(YSQLSelect s) { + sb.append("SELECT "); + switch (s.getSelectOption()) { + case DISTINCT: + sb.append("DISTINCT "); + if (s.getDistinctOnClause() != null) { + sb.append("ON ("); + visit(s.getDistinctOnClause()); + sb.append(") "); + } + break; + case ALL: + sb.append(Randomly.fromOptions("ALL ", "")); + break; + default: + throw new AssertionError(); + } + if (s.getFetchColumns() == null) { + sb.append("*"); + } else { + visit(s.getFetchColumns()); + } + sb.append(" FROM "); + visit(s.getFromList()); + + for (YSQLJoin j : s.getJoinClauses()) { + sb.append(" "); + switch (j.getType()) { + case INNER: + if (Randomly.getBoolean()) { + sb.append("INNER "); + } + sb.append("JOIN"); + break; + case LEFT: + sb.append("LEFT OUTER JOIN"); + break; + case RIGHT: + sb.append("RIGHT OUTER JOIN"); + break; + case FULL: + sb.append("FULL OUTER JOIN"); + break; + case CROSS: + sb.append("CROSS JOIN"); + break; + default: + throw new AssertionError(j.getType()); + } + sb.append(" "); + visit(j.getTableReference()); + if (j.getType() != YSQLJoinType.CROSS) { + sb.append(" ON "); + visit(j.getOnClause()); + } + } + + if (s.getWhereClause() != null) { + sb.append(" WHERE "); + visit(s.getWhereClause()); + } + if (!s.getGroupByExpressions().isEmpty()) { + sb.append(" GROUP BY "); + visit(s.getGroupByExpressions()); + } + if (s.getHavingClause() != null) { + sb.append(" HAVING "); + visit(s.getHavingClause()); + + } + if (!s.getOrderByClauses().isEmpty()) { + sb.append(" ORDER BY "); + visit(s.getOrderByClauses()); + } + if (s.getLimitClause() != null) { + sb.append(" LIMIT "); + visit(s.getLimitClause()); + } + + if (s.getOffsetClause() != null) { + sb.append(" OFFSET "); + visit(s.getOffsetClause()); + } + } + + @Override + public void visit(YSQLOrderByTerm op) { + visit(op.getExpr()); + sb.append(" "); + sb.append(op.getOrder()); + } + + @Override + public void visit(YSQLFunction f) { + sb.append(f.getFunctionName()); + sb.append("("); + int i = 0; + for (YSQLExpression arg : f.getArguments()) { + if (i++ != 0) { + sb.append(", "); + } + visit(arg); + } + sb.append(")"); + } + + @Override + public void visit(YSQLCastOperation cast) { + if (Randomly.getBoolean()) { + sb.append("CAST("); + visit(cast.getExpression()); + sb.append(" AS "); + appendType(cast); + sb.append(")"); + } else { + sb.append("("); + visit(cast.getExpression()); + sb.append(")::"); + appendType(cast); + } + } + + @Override + public void visit(YSQLBetweenOperation op) { + sb.append("("); + visit(op.getExpr()); + sb.append(") BETWEEN "); + if (op.isSymmetric()) { + sb.append("SYMMETRIC "); + } + sb.append("("); + visit(op.getLeft()); + sb.append(") AND ("); + visit(op.getRight()); + sb.append(")"); + } + + @Override + public void visit(YSQLInOperation op) { + sb.append("("); + visit(op.getExpr()); + sb.append(")"); + if (!op.isTrue()) { + sb.append(" NOT"); + } + sb.append(" IN ("); + visit(op.getListElements()); + sb.append(")"); + } + + @Override + public void visit(YSQLPostfixText op) { + visit(op.getExpr()); + sb.append(op.getText()); + } + + @Override + public void visit(YSQLAggregate op) { + sb.append(op.getFunction()); + sb.append("("); + visit(op.getArgs()); + sb.append(")"); + } + + @Override + public void visit(YSQLSimilarTo op) { + sb.append("("); + visit(op.getString()); + sb.append(" SIMILAR TO "); + visit(op.getSimilarTo()); + if (op.getEscapeCharacter() != null) { + visit(op.getEscapeCharacter()); + } + sb.append(")"); + } + + @Override + public void visit(YSQLPOSIXRegularExpression op) { + visit(op.getString()); + sb.append(op.getOp().getStringRepresentation()); + visit(op.getRegex()); + } + + @Override + public void visit(YSQLFromTable from) { + if (from.isOnly()) { + sb.append("ONLY "); + } + sb.append(from.getTable().getName()); + if (!from.isOnly() && Randomly.getBoolean()) { + sb.append("*"); + } + } + + @Override + public void visit(YSQLSubquery subquery) { + sb.append("("); + visit(subquery.getSelect()); + sb.append(") AS "); + sb.append(subquery.getName()); + } + + @Override + public void visit(YSQLBinaryLogicalOperation op) { + super.visit((BinaryOperation) op); + } + + private void appendType(YSQLCastOperation cast) { + YSQLCompoundDataType compoundType = cast.getCompoundType(); + switch (compoundType.getDataType()) { + case BOOLEAN: + sb.append("BOOLEAN"); + break; + case INT: // TODO support also other int types + sb.append("INT"); + break; + case TEXT: + // TODO: append TEXT, CHAR + sb.append(Randomly.fromOptions("VARCHAR")); + break; + case REAL: + sb.append("REAL"); + break; + case DECIMAL: + sb.append("DECIMAL"); + break; + case FLOAT: + sb.append("FLOAT"); + break; + case RANGE: + sb.append("int4range"); + break; + case MONEY: + sb.append("MONEY"); + break; + case INET: + sb.append("INET"); + break; + case BIT: + sb.append("BIT"); + break; + case BYTEA: + sb.append("BYTEA"); + break; + // if (Randomly.getBoolean()) { + // sb.append("("); + // sb.append(Randomly.getNotCachedInteger(1, 100)); + // sb.append(")"); + // } + default: + throw new AssertionError(cast.getType()); + } + Optional size = compoundType.getSize(); + if (size.isPresent()) { + sb.append("("); + sb.append(size.get()); + sb.append(")"); + } + } + +} diff --git a/src/sqlancer/yugabyte/ysql/YSQLVisitor.java b/src/sqlancer/yugabyte/ysql/YSQLVisitor.java new file mode 100644 index 000000000..a73aade4c --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/YSQLVisitor.java @@ -0,0 +1,120 @@ +package sqlancer.yugabyte.ysql; + +import java.util.List; + +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLColumn; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; +import sqlancer.yugabyte.ysql.ast.YSQLAggregate; +import sqlancer.yugabyte.ysql.ast.YSQLBetweenOperation; +import sqlancer.yugabyte.ysql.ast.YSQLBinaryLogicalOperation; +import sqlancer.yugabyte.ysql.ast.YSQLCastOperation; +import sqlancer.yugabyte.ysql.ast.YSQLColumnValue; +import sqlancer.yugabyte.ysql.ast.YSQLConstant; +import sqlancer.yugabyte.ysql.ast.YSQLExpression; +import sqlancer.yugabyte.ysql.ast.YSQLFunction; +import sqlancer.yugabyte.ysql.ast.YSQLInOperation; +import sqlancer.yugabyte.ysql.ast.YSQLOrderByTerm; +import sqlancer.yugabyte.ysql.ast.YSQLPOSIXRegularExpression; +import sqlancer.yugabyte.ysql.ast.YSQLPostfixOperation; +import sqlancer.yugabyte.ysql.ast.YSQLPostfixText; +import sqlancer.yugabyte.ysql.ast.YSQLPrefixOperation; +import sqlancer.yugabyte.ysql.ast.YSQLSelect; +import sqlancer.yugabyte.ysql.ast.YSQLSelect.YSQLFromTable; +import sqlancer.yugabyte.ysql.ast.YSQLSelect.YSQLSubquery; +import sqlancer.yugabyte.ysql.ast.YSQLSimilarTo; +import sqlancer.yugabyte.ysql.gen.YSQLExpressionGenerator; + +public interface YSQLVisitor { + + static String asString(YSQLExpression expr) { + YSQLToStringVisitor visitor = new YSQLToStringVisitor(); + visitor.visit(expr); + return visitor.get(); + } + + static String asExpectedValues(YSQLExpression expr) { + YSQLExpectedValueVisitor v = new YSQLExpectedValueVisitor(); + v.visit(expr); + return v.get(); + } + + static String getExpressionAsString(YSQLGlobalState globalState, YSQLDataType type, List columns) { + YSQLExpression expression = YSQLExpressionGenerator.generateExpression(globalState, columns, type); + YSQLToStringVisitor visitor = new YSQLToStringVisitor(); + visitor.visit(expression); + return visitor.get(); + } + + void visit(YSQLConstant constant); + + void visit(YSQLPostfixOperation op); + + void visit(YSQLColumnValue c); + + void visit(YSQLPrefixOperation op); + + void visit(YSQLSelect op); + + void visit(YSQLOrderByTerm op); + + void visit(YSQLFunction f); + + void visit(YSQLCastOperation cast); + + void visit(YSQLBetweenOperation op); + + void visit(YSQLInOperation op); + + void visit(YSQLPostfixText op); + + void visit(YSQLAggregate op); + + void visit(YSQLSimilarTo op); + + void visit(YSQLPOSIXRegularExpression op); + + void visit(YSQLFromTable from); + + void visit(YSQLSubquery subquery); + + void visit(YSQLBinaryLogicalOperation op); + + default void visit(YSQLExpression expression) { + if (expression instanceof YSQLConstant) { + visit((YSQLConstant) expression); + } else if (expression instanceof YSQLPostfixOperation) { + visit((YSQLPostfixOperation) expression); + } else if (expression instanceof YSQLColumnValue) { + visit((YSQLColumnValue) expression); + } else if (expression instanceof YSQLPrefixOperation) { + visit((YSQLPrefixOperation) expression); + } else if (expression instanceof YSQLSelect) { + visit((YSQLSelect) expression); + } else if (expression instanceof YSQLOrderByTerm) { + visit((YSQLOrderByTerm) expression); + } else if (expression instanceof YSQLFunction) { + visit((YSQLFunction) expression); + } else if (expression instanceof YSQLCastOperation) { + visit((YSQLCastOperation) expression); + } else if (expression instanceof YSQLBetweenOperation) { + visit((YSQLBetweenOperation) expression); + } else if (expression instanceof YSQLInOperation) { + visit((YSQLInOperation) expression); + } else if (expression instanceof YSQLAggregate) { + visit((YSQLAggregate) expression); + } else if (expression instanceof YSQLPostfixText) { + visit((YSQLPostfixText) expression); + } else if (expression instanceof YSQLSimilarTo) { + visit((YSQLSimilarTo) expression); + } else if (expression instanceof YSQLPOSIXRegularExpression) { + visit((YSQLPOSIXRegularExpression) expression); + } else if (expression instanceof YSQLFromTable) { + visit((YSQLFromTable) expression); + } else if (expression instanceof YSQLSubquery) { + visit((YSQLSubquery) expression); + } else { + throw new AssertionError(expression); + } + } + +} diff --git a/src/sqlancer/yugabyte/ysql/ast/YSQLAggregate.java b/src/sqlancer/yugabyte/ysql/ast/YSQLAggregate.java new file mode 100644 index 000000000..27daefcf9 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/ast/YSQLAggregate.java @@ -0,0 +1,58 @@ +package sqlancer.yugabyte.ysql.ast; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.common.ast.FunctionNode; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; +import sqlancer.yugabyte.ysql.ast.YSQLAggregate.YSQLAggregateFunction; + +/** + * @see Built-in Aggregate Functions + */ +public class YSQLAggregate extends FunctionNode implements YSQLExpression { + + public YSQLAggregate(List args, YSQLAggregateFunction func) { + super(func, args); + } + + public enum YSQLAggregateFunction { + AVG(YSQLDataType.INT, YSQLDataType.FLOAT, YSQLDataType.REAL, YSQLDataType.DECIMAL), BIT_AND(YSQLDataType.INT), + BIT_OR(YSQLDataType.INT), BOOL_AND(YSQLDataType.BOOLEAN), BOOL_OR(YSQLDataType.BOOLEAN), + COUNT(YSQLDataType.INT), EVERY(YSQLDataType.BOOLEAN), MAX, MIN, + // STRING_AGG + SUM(YSQLDataType.INT, YSQLDataType.FLOAT, YSQLDataType.REAL, YSQLDataType.DECIMAL); + + private final YSQLDataType[] supportedReturnTypes; + + YSQLAggregateFunction(YSQLDataType... supportedReturnTypes) { + this.supportedReturnTypes = supportedReturnTypes.clone(); + } + + public static List getAggregates(YSQLDataType type) { + return Arrays.stream(values()).filter(p -> p.supportsReturnType(type)).collect(Collectors.toList()); + } + + public List getTypes(YSQLDataType returnType) { + return Collections.singletonList(returnType); + } + + public boolean supportsReturnType(YSQLDataType returnType) { + return Arrays.stream(supportedReturnTypes).anyMatch(t -> t == returnType) + || supportedReturnTypes.length == 0; + } + + public YSQLDataType getRandomReturnType() { + if (supportedReturnTypes.length == 0) { + return Randomly.fromOptions(YSQLDataType.getRandomType()); + } else { + return Randomly.fromOptions(supportedReturnTypes); + } + } + + } + +} diff --git a/src/sqlancer/yugabyte/ysql/ast/YSQLAlias.java b/src/sqlancer/yugabyte/ysql/ast/YSQLAlias.java new file mode 100644 index 000000000..96e432370 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/ast/YSQLAlias.java @@ -0,0 +1,35 @@ +package sqlancer.yugabyte.ysql.ast; + +import sqlancer.common.visitor.UnaryOperation; + +public class YSQLAlias implements UnaryOperation, YSQLExpression { + + private final YSQLExpression expr; + private final String alias; + + public YSQLAlias(YSQLExpression expr, String alias) { + this.expr = expr; + this.alias = alias; + } + + @Override + public YSQLExpression getExpression() { + return expr; + } + + @Override + public String getOperatorRepresentation() { + return " as " + alias; + } + + @Override + public boolean omitBracketsWhenPrinting() { + return true; + } + + @Override + public OperatorKind getOperatorKind() { + return OperatorKind.POSTFIX; + } + +} diff --git a/src/sqlancer/yugabyte/ysql/ast/YSQLBetweenOperation.java b/src/sqlancer/yugabyte/ysql/ast/YSQLBetweenOperation.java new file mode 100644 index 000000000..4a6c123d0 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/ast/YSQLBetweenOperation.java @@ -0,0 +1,63 @@ +package sqlancer.yugabyte.ysql.ast; + +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; + +public final class YSQLBetweenOperation implements YSQLExpression { + + private final YSQLExpression expr; + private final YSQLExpression left; + private final YSQLExpression right; + private final boolean isSymmetric; + + public YSQLBetweenOperation(YSQLExpression expr, YSQLExpression left, YSQLExpression right, boolean symmetric) { + this.expr = expr; + this.left = left; + this.right = right; + isSymmetric = symmetric; + } + + public YSQLExpression getExpr() { + return expr; + } + + public YSQLExpression getLeft() { + return left; + } + + public YSQLExpression getRight() { + return right; + } + + public boolean isSymmetric() { + return isSymmetric; + } + + @Override + public YSQLDataType getExpressionType() { + return YSQLDataType.BOOLEAN; + } + + @Override + public YSQLConstant getExpectedValue() { + YSQLBinaryComparisonOperation leftComparison = new YSQLBinaryComparisonOperation(left, expr, + YSQLBinaryComparisonOperation.YSQLBinaryComparisonOperator.LESS_EQUALS); + YSQLBinaryComparisonOperation rightComparison = new YSQLBinaryComparisonOperation(expr, right, + YSQLBinaryComparisonOperation.YSQLBinaryComparisonOperator.LESS_EQUALS); + YSQLBinaryLogicalOperation andOperation = new YSQLBinaryLogicalOperation(leftComparison, rightComparison, + YSQLBinaryLogicalOperation.BinaryLogicalOperator.AND); + if (isSymmetric) { + YSQLBinaryComparisonOperation leftComparison2 = new YSQLBinaryComparisonOperation(right, expr, + YSQLBinaryComparisonOperation.YSQLBinaryComparisonOperator.LESS_EQUALS); + YSQLBinaryComparisonOperation rightComparison2 = new YSQLBinaryComparisonOperation(expr, left, + YSQLBinaryComparisonOperation.YSQLBinaryComparisonOperator.LESS_EQUALS); + YSQLBinaryLogicalOperation andOperation2 = new YSQLBinaryLogicalOperation(leftComparison2, rightComparison2, + YSQLBinaryLogicalOperation.BinaryLogicalOperator.AND); + YSQLBinaryLogicalOperation orOp = new YSQLBinaryLogicalOperation(andOperation, andOperation2, + YSQLBinaryLogicalOperation.BinaryLogicalOperator.OR); + return orOp.getExpectedValue(); + } else { + return andOperation.getExpectedValue(); + } + } + +} diff --git a/src/sqlancer/yugabyte/ysql/ast/YSQLBinaryArithmeticOperation.java b/src/sqlancer/yugabyte/ysql/ast/YSQLBinaryArithmeticOperation.java new file mode 100644 index 000000000..a4385c86c --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/ast/YSQLBinaryArithmeticOperation.java @@ -0,0 +1,106 @@ +package sqlancer.yugabyte.ysql.ast; + +import java.util.function.BinaryOperator; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; +import sqlancer.yugabyte.ysql.ast.YSQLBinaryArithmeticOperation.YSQLBinaryOperator; + +public class YSQLBinaryArithmeticOperation extends BinaryOperatorNode + implements YSQLExpression { + + public YSQLBinaryArithmeticOperation(YSQLExpression left, YSQLExpression right, YSQLBinaryOperator op) { + super(left, right, op); + } + + @Override + public YSQLDataType getExpressionType() { + return YSQLDataType.INT; + } + + @Override + public YSQLConstant getExpectedValue() { + YSQLConstant leftExpected = getLeft().getExpectedValue(); + YSQLConstant rightExpected = getRight().getExpectedValue(); + if (leftExpected == null || rightExpected == null) { + return null; + } + return getOp().apply(leftExpected, rightExpected); + } + + public enum YSQLBinaryOperator implements Operator { + + ADDITION("+") { + @Override + public YSQLConstant apply(YSQLConstant left, YSQLConstant right) { + return applyBitOperation(left, right, Long::sum); + } + + }, + SUBTRACTION("-") { + @Override + public YSQLConstant apply(YSQLConstant left, YSQLConstant right) { + return applyBitOperation(left, right, (l, r) -> l - r); + } + }, + MULTIPLICATION("*") { + @Override + public YSQLConstant apply(YSQLConstant left, YSQLConstant right) { + return applyBitOperation(left, right, (l, r) -> l * r); + } + }, + DIVISION("/") { + @Override + public YSQLConstant apply(YSQLConstant left, YSQLConstant right) { + return applyBitOperation(left, right, (l, r) -> r == 0 ? -1 : l / r); + + } + + }, + MODULO("%") { + @Override + public YSQLConstant apply(YSQLConstant left, YSQLConstant right) { + return applyBitOperation(left, right, (l, r) -> r == 0 ? -1 : l % r); + + } + }, + EXPONENTIATION("^") { + @Override + public YSQLConstant apply(YSQLConstant left, YSQLConstant right) { + return null; + } + }; + + private final String textRepresentation; + + YSQLBinaryOperator(String textRepresentation) { + this.textRepresentation = textRepresentation; + } + + private static YSQLConstant applyBitOperation(YSQLConstant left, YSQLConstant right, BinaryOperator op) { + if (left.isNull() || right.isNull()) { + return YSQLConstant.createNullConstant(); + } else { + long leftVal = left.cast(YSQLDataType.INT).asInt(); + long rightVal = right.cast(YSQLDataType.INT).asInt(); + long value = op.apply(leftVal, rightVal); + return YSQLConstant.createIntConstant(value); + } + } + + public static YSQLBinaryOperator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepresentation; + } + + public abstract YSQLConstant apply(YSQLConstant left, YSQLConstant right); + + } + +} diff --git a/src/sqlancer/yugabyte/ysql/ast/YSQLBinaryBitOperation.java b/src/sqlancer/yugabyte/ysql/ast/YSQLBinaryBitOperation.java new file mode 100644 index 000000000..0d1cb8f2e --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/ast/YSQLBinaryBitOperation.java @@ -0,0 +1,46 @@ +package sqlancer.yugabyte.ysql.ast; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; +import sqlancer.yugabyte.ysql.ast.YSQLBinaryBitOperation.YSQLBinaryBitOperator; + +public class YSQLBinaryBitOperation extends BinaryOperatorNode + implements YSQLExpression { + + public YSQLBinaryBitOperation(YSQLBinaryBitOperator op, YSQLExpression left, YSQLExpression right) { + super(left, right, op); + } + + @Override + public YSQLDataType getExpressionType() { + return YSQLDataType.BIT; + } + + public enum YSQLBinaryBitOperator implements Operator { + CONCATENATION("||"), // + BITWISE_AND("&"), // + BITWISE_OR("|"), // + BITWISE_XOR("#"), // + BITWISE_SHIFT_LEFT("<<"), // + BITWISE_SHIFT_RIGHT(">>"); + + private final String text; + + YSQLBinaryBitOperator(String text) { + this.text = text; + } + + public static YSQLBinaryBitOperator getRandom() { + return Randomly.fromOptions(YSQLBinaryBitOperator.values()); + } + + @Override + public String getTextRepresentation() { + return text; + } + + } + +} diff --git a/src/sqlancer/yugabyte/ysql/ast/YSQLBinaryComparisonOperation.java b/src/sqlancer/yugabyte/ysql/ast/YSQLBinaryComparisonOperation.java new file mode 100644 index 000000000..8ced603ac --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/ast/YSQLBinaryComparisonOperation.java @@ -0,0 +1,135 @@ +package sqlancer.yugabyte.ysql.ast; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; +import sqlancer.yugabyte.ysql.ast.YSQLBinaryComparisonOperation.YSQLBinaryComparisonOperator; + +public class YSQLBinaryComparisonOperation extends BinaryOperatorNode + implements YSQLExpression { + + public YSQLBinaryComparisonOperation(YSQLExpression left, YSQLExpression right, YSQLBinaryComparisonOperator op) { + super(left, right, op); + } + + @Override + public YSQLDataType getExpressionType() { + return YSQLDataType.BOOLEAN; + } + + @Override + public YSQLConstant getExpectedValue() { + YSQLConstant leftExpectedValue = getLeft().getExpectedValue(); + YSQLConstant rightExpectedValue = getRight().getExpectedValue(); + if (leftExpectedValue == null || rightExpectedValue == null) { + return null; + } + return getOp().getExpectedValue(leftExpectedValue, rightExpectedValue); + } + + public enum YSQLBinaryComparisonOperator implements Operator { + EQUALS("=") { + @Override + public YSQLConstant getExpectedValue(YSQLConstant leftVal, YSQLConstant rightVal) { + return leftVal.isEquals(rightVal); + } + }, + IS_DISTINCT("IS DISTINCT FROM") { + @Override + public YSQLConstant getExpectedValue(YSQLConstant leftVal, YSQLConstant rightVal) { + return YSQLConstant + .createBooleanConstant(!IS_NOT_DISTINCT.getExpectedValue(leftVal, rightVal).asBoolean()); + } + }, + IS_NOT_DISTINCT("IS NOT DISTINCT FROM") { + @Override + public YSQLConstant getExpectedValue(YSQLConstant leftVal, YSQLConstant rightVal) { + if (leftVal.isNull()) { + return YSQLConstant.createBooleanConstant(rightVal.isNull()); + } else if (rightVal.isNull()) { + return YSQLConstant.createFalse(); + } else { + return leftVal.isEquals(rightVal); + } + } + }, + NOT_EQUALS("!=") { + @Override + public YSQLConstant getExpectedValue(YSQLConstant leftVal, YSQLConstant rightVal) { + YSQLConstant isEquals = leftVal.isEquals(rightVal); + if (isEquals.isBoolean()) { + return YSQLConstant.createBooleanConstant(!isEquals.asBoolean()); + } + return isEquals; + } + }, + LESS("<") { + @Override + public YSQLConstant getExpectedValue(YSQLConstant leftVal, YSQLConstant rightVal) { + return leftVal.isLessThan(rightVal); + } + }, + LESS_EQUALS("<=") { + @Override + public YSQLConstant getExpectedValue(YSQLConstant leftVal, YSQLConstant rightVal) { + YSQLConstant lessThan = leftVal.isLessThan(rightVal); + if (lessThan.isBoolean() && !lessThan.asBoolean()) { + return leftVal.isEquals(rightVal); + } else { + return lessThan; + } + } + }, + GREATER(">") { + @Override + public YSQLConstant getExpectedValue(YSQLConstant leftVal, YSQLConstant rightVal) { + YSQLConstant equals = leftVal.isEquals(rightVal); + if (equals.isBoolean() && equals.asBoolean()) { + return YSQLConstant.createFalse(); + } else { + YSQLConstant applyLess = leftVal.isLessThan(rightVal); + if (applyLess.isNull()) { + return YSQLConstant.createNullConstant(); + } + return YSQLPrefixOperation.PrefixOperator.NOT.getExpectedValue(applyLess); + } + } + }, + GREATER_EQUALS(">=") { + @Override + public YSQLConstant getExpectedValue(YSQLConstant leftVal, YSQLConstant rightVal) { + YSQLConstant equals = leftVal.isEquals(rightVal); + if (equals.isBoolean() && equals.asBoolean()) { + return YSQLConstant.createTrue(); + } else { + YSQLConstant applyLess = leftVal.isLessThan(rightVal); + if (applyLess.isNull()) { + return YSQLConstant.createNullConstant(); + } + return YSQLPrefixOperation.PrefixOperator.NOT.getExpectedValue(applyLess); + } + } + + }; + + private final String textRepresentation; + + YSQLBinaryComparisonOperator(String textRepresentation) { + this.textRepresentation = textRepresentation; + } + + public static YSQLBinaryComparisonOperator getRandom() { + return Randomly.fromOptions(YSQLBinaryComparisonOperator.values()); + } + + @Override + public String getTextRepresentation() { + return textRepresentation; + } + + public abstract YSQLConstant getExpectedValue(YSQLConstant leftVal, YSQLConstant rightVal); + + } + +} diff --git a/src/sqlancer/yugabyte/ysql/ast/YSQLBinaryLogicalOperation.java b/src/sqlancer/yugabyte/ysql/ast/YSQLBinaryLogicalOperation.java new file mode 100644 index 000000000..89cce762f --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/ast/YSQLBinaryLogicalOperation.java @@ -0,0 +1,88 @@ +package sqlancer.yugabyte.ysql.ast; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode; +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; +import sqlancer.yugabyte.ysql.ast.YSQLBinaryLogicalOperation.BinaryLogicalOperator; + +public class YSQLBinaryLogicalOperation extends BinaryOperatorNode + implements YSQLExpression { + + public YSQLBinaryLogicalOperation(YSQLExpression left, YSQLExpression right, BinaryLogicalOperator op) { + super(left, right, op); + } + + @Override + public YSQLDataType getExpressionType() { + return YSQLDataType.BOOLEAN; + } + + @Override + public YSQLConstant getExpectedValue() { + YSQLConstant leftExpectedValue = getLeft().getExpectedValue(); + YSQLConstant rightExpectedValue = getRight().getExpectedValue(); + if (leftExpectedValue == null || rightExpectedValue == null) { + return null; + } + return getOp().apply(leftExpectedValue, rightExpectedValue); + } + + public enum BinaryLogicalOperator implements Operator { + AND { + @Override + public YSQLConstant apply(YSQLConstant left, YSQLConstant right) { + YSQLConstant leftBool = left.cast(YSQLDataType.BOOLEAN); + YSQLConstant rightBool = right.cast(YSQLDataType.BOOLEAN); + if (leftBool.isNull()) { + if (rightBool.isNull()) { + return YSQLConstant.createNullConstant(); + } else { + if (rightBool.asBoolean()) { + return YSQLConstant.createNullConstant(); + } else { + return YSQLConstant.createFalse(); + } + } + } else if (!leftBool.asBoolean()) { + return YSQLConstant.createFalse(); + } + assert leftBool.asBoolean(); + if (rightBool.isNull()) { + return YSQLConstant.createNullConstant(); + } else { + return YSQLConstant.createBooleanConstant(rightBool.isBoolean() && rightBool.asBoolean()); + } + } + }, + OR { + @Override + public YSQLConstant apply(YSQLConstant left, YSQLConstant right) { + YSQLConstant leftBool = left.cast(YSQLDataType.BOOLEAN); + YSQLConstant rightBool = right.cast(YSQLDataType.BOOLEAN); + if (leftBool.isBoolean() && leftBool.asBoolean()) { + return YSQLConstant.createTrue(); + } + if (rightBool.isBoolean() && rightBool.asBoolean()) { + return YSQLConstant.createTrue(); + } + if (leftBool.isNull() || rightBool.isNull()) { + return YSQLConstant.createNullConstant(); + } + return YSQLConstant.createFalse(); + } + }; + + public static BinaryLogicalOperator getRandom() { + return Randomly.fromOptions(values()); + } + + public abstract YSQLConstant apply(YSQLConstant left, YSQLConstant right); + + @Override + public String getTextRepresentation() { + return toString(); + } + } + +} diff --git a/src/sqlancer/yugabyte/ysql/ast/YSQLBinaryRangeOperation.java b/src/sqlancer/yugabyte/ysql/ast/YSQLBinaryRangeOperation.java new file mode 100644 index 000000000..4bf4a8ea6 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/ast/YSQLBinaryRangeOperation.java @@ -0,0 +1,71 @@ +package sqlancer.yugabyte.ysql.ast; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryNode; +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; + +public class YSQLBinaryRangeOperation extends BinaryNode implements YSQLExpression { + + private final String op; + + public YSQLBinaryRangeOperation(YSQLBinaryRangeComparisonOperator op, YSQLExpression left, YSQLExpression right) { + super(left, right); + this.op = op.getTextRepresentation(); + } + + public YSQLBinaryRangeOperation(YSQLBinaryRangeOperator op, YSQLExpression left, YSQLExpression right) { + super(left, right); + this.op = op.getTextRepresentation(); + } + + @Override + public YSQLDataType getExpressionType() { + return YSQLDataType.BOOLEAN; + } + + @Override + public String getOperatorRepresentation() { + return op; + } + + public enum YSQLBinaryRangeOperator implements Operator { + UNION("+"), INTERSECTION("*"), DIFFERENCE("-"); + + private final String textRepresentation; + + YSQLBinaryRangeOperator(String textRepresentation) { + this.textRepresentation = textRepresentation; + } + + public static YSQLBinaryRangeOperator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepresentation; + } + + } + + public enum YSQLBinaryRangeComparisonOperator { + CONTAINS_RANGE_OR_ELEMENT("@>"), RANGE_OR_ELEMENT_IS_CONTAINED("<@"), OVERLAP("&&"), STRICT_LEFT_OF("<<"), + STRICT_RIGHT_OF(">>"), NOT_RIGHT_OF("&<"), NOT_LEFT_OF(">&"), ADJACENT("-|-"); + + private final String textRepresentation; + + YSQLBinaryRangeComparisonOperator(String textRepresentation) { + this.textRepresentation = textRepresentation; + } + + public static YSQLBinaryRangeComparisonOperator getRandom() { + return Randomly.fromOptions(values()); + } + + public String getTextRepresentation() { + return textRepresentation; + } + } + +} diff --git a/src/sqlancer/yugabyte/ysql/ast/YSQLCastOperation.java b/src/sqlancer/yugabyte/ysql/ast/YSQLCastOperation.java new file mode 100644 index 000000000..cacd8ad76 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/ast/YSQLCastOperation.java @@ -0,0 +1,45 @@ +package sqlancer.yugabyte.ysql.ast; + +import sqlancer.yugabyte.ysql.YSQLCompoundDataType; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; + +public class YSQLCastOperation implements YSQLExpression { + + private final YSQLExpression expression; + private final YSQLCompoundDataType type; + + public YSQLCastOperation(YSQLExpression expression, YSQLCompoundDataType type) { + if (expression == null) { + throw new AssertionError(); + } + this.expression = expression; + this.type = type; + } + + @Override + public YSQLDataType getExpressionType() { + return type.getDataType(); + } + + @Override + public YSQLConstant getExpectedValue() { + YSQLConstant expectedValue = expression.getExpectedValue(); + if (expectedValue == null) { + return null; + } + return expectedValue.cast(type.getDataType()); + } + + public YSQLExpression getExpression() { + return expression; + } + + public YSQLDataType getType() { + return type.getDataType(); + } + + public YSQLCompoundDataType getCompoundType() { + return type; + } + +} diff --git a/src/sqlancer/yugabyte/ysql/ast/YSQLCollate.java b/src/sqlancer/yugabyte/ysql/ast/YSQLCollate.java new file mode 100644 index 000000000..82b164c54 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/ast/YSQLCollate.java @@ -0,0 +1,33 @@ +package sqlancer.yugabyte.ysql.ast; + +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; + +public class YSQLCollate implements YSQLExpression { + + private final YSQLExpression expr; + private final String collate; + + public YSQLCollate(YSQLExpression expr, String collate) { + this.expr = expr; + this.collate = collate; + } + + public String getCollate() { + return collate; + } + + public YSQLExpression getExpr() { + return expr; + } + + @Override + public YSQLDataType getExpressionType() { + return expr.getExpressionType(); + } + + @Override + public YSQLConstant getExpectedValue() { + return null; + } + +} diff --git a/src/sqlancer/yugabyte/ysql/ast/YSQLColumnValue.java b/src/sqlancer/yugabyte/ysql/ast/YSQLColumnValue.java new file mode 100644 index 000000000..243bdaf57 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/ast/YSQLColumnValue.java @@ -0,0 +1,34 @@ +package sqlancer.yugabyte.ysql.ast; + +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLColumn; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; + +public class YSQLColumnValue implements YSQLExpression { + + private final YSQLColumn c; + private final YSQLConstant expectedValue; + + public YSQLColumnValue(YSQLColumn c, YSQLConstant expectedValue) { + this.c = c; + this.expectedValue = expectedValue; + } + + public static YSQLColumnValue create(YSQLColumn c, YSQLConstant expected) { + return new YSQLColumnValue(c, expected); + } + + @Override + public YSQLDataType getExpressionType() { + return c.getType(); + } + + @Override + public YSQLConstant getExpectedValue() { + return expectedValue; + } + + public YSQLColumn getColumn() { + return c; + } + +} diff --git a/src/sqlancer/yugabyte/ysql/ast/YSQLConcatOperation.java b/src/sqlancer/yugabyte/ysql/ast/YSQLConcatOperation.java new file mode 100644 index 000000000..a74263b22 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/ast/YSQLConcatOperation.java @@ -0,0 +1,37 @@ +package sqlancer.yugabyte.ysql.ast; + +import sqlancer.common.ast.BinaryNode; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; + +public class YSQLConcatOperation extends BinaryNode implements YSQLExpression { + + public YSQLConcatOperation(YSQLExpression left, YSQLExpression right) { + super(left, right); + } + + @Override + public YSQLDataType getExpressionType() { + return YSQLDataType.TEXT; + } + + @Override + public YSQLConstant getExpectedValue() { + YSQLConstant leftExpectedValue = getLeft().getExpectedValue(); + YSQLConstant rightExpectedValue = getRight().getExpectedValue(); + if (leftExpectedValue == null || rightExpectedValue == null) { + return null; + } + if (leftExpectedValue.isNull() || rightExpectedValue.isNull()) { + return YSQLConstant.createNullConstant(); + } + String leftStr = leftExpectedValue.cast(YSQLDataType.TEXT).getUnquotedTextRepresentation(); + String rightStr = rightExpectedValue.cast(YSQLDataType.TEXT).getUnquotedTextRepresentation(); + return YSQLConstant.createTextConstant(leftStr + rightStr); + } + + @Override + public String getOperatorRepresentation() { + return "||"; + } + +} diff --git a/src/sqlancer/yugabyte/ysql/ast/YSQLConstant.java b/src/sqlancer/yugabyte/ysql/ast/YSQLConstant.java new file mode 100644 index 000000000..9fd2eaad6 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/ast/YSQLConstant.java @@ -0,0 +1,611 @@ +package sqlancer.yugabyte.ysql.ast; + +import java.math.BigDecimal; + +import sqlancer.IgnoreMeException; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; + +public abstract class YSQLConstant implements YSQLExpression { + + public static YSQLConstant createNullConstant() { + return new YSQLNullConstant(); + } + + public static YSQLConstant createIntConstant(long val) { + return new IntConstant(val); + } + + public static YSQLConstant createBooleanConstant(boolean val) { + return new BooleanConstant(val); + } + + public static YSQLConstant createFalse() { + return createBooleanConstant(false); + } + + public static YSQLConstant createTrue() { + return createBooleanConstant(true); + } + + public static YSQLConstant createTextConstant(String string) { + return new StringConstant(string); + } + + public static YSQLConstant createByteConstant(String string) { + return new ByteConstant(string); + } + + public static YSQLConstant createDecimalConstant(BigDecimal bigDecimal) { + return new DecimalConstant(bigDecimal); + } + + public static YSQLConstant createFloatConstant(float val) { + return new FloatConstant(val); + } + + public static YSQLConstant createDoubleConstant(double val) { + return new DoubleConstant(val); + } + + public static YSQLConstant createRange(long left, boolean leftIsInclusive, long right, boolean rightIsInclusive) { + long realLeft; + long realRight; + if (left > right) { + realRight = left; + realLeft = right; + } else { + realLeft = left; + realRight = right; + } + return new RangeConstant(realLeft, leftIsInclusive, realRight, rightIsInclusive); + } + + public static YSQLExpression createBitConstant(long integer) { + return new BitConstant(integer); + } + + public static YSQLExpression createInetConstant(String val) { + return new InetConstant(val); + } + + public abstract String getTextRepresentation(); + + public abstract String getUnquotedTextRepresentation(); + + public String asString() { + throw new UnsupportedOperationException(this.toString()); + } + + public boolean isString() { + return false; + } + + @Override + public YSQLConstant getExpectedValue() { + return this; + } + + public boolean isNull() { + return false; + } + + public boolean asBoolean() { + throw new UnsupportedOperationException(this.toString()); + } + + public long asInt() { + throw new UnsupportedOperationException(this.toString()); + } + + public boolean isBoolean() { + return false; + } + + public abstract YSQLConstant isEquals(YSQLConstant rightVal); + + public boolean isInt() { + return false; + } + + protected abstract YSQLConstant isLessThan(YSQLConstant rightVal); + + @Override + public String toString() { + return getTextRepresentation(); + } + + public abstract YSQLConstant cast(YSQLDataType type); + + public static class BooleanConstant extends YSQLConstant { + + private final boolean value; + + public BooleanConstant(boolean value) { + this.value = value; + } + + @Override + public String getTextRepresentation() { + return value ? "TRUE" : "FALSE"; + } + + @Override + public String getUnquotedTextRepresentation() { + return getTextRepresentation(); + } + + @Override + public boolean asBoolean() { + return value; + } + + @Override + public boolean isBoolean() { + return true; + } + + @Override + public YSQLConstant isEquals(YSQLConstant rightVal) { + if (rightVal.isNull()) { + return YSQLConstant.createNullConstant(); + } else if (rightVal.isBoolean()) { + return YSQLConstant.createBooleanConstant(value == rightVal.asBoolean()); + } else if (rightVal.isString()) { + return YSQLConstant.createBooleanConstant(value == rightVal.cast(YSQLDataType.BOOLEAN).asBoolean()); + } else { + throw new AssertionError(rightVal); + } + } + + @Override + protected YSQLConstant isLessThan(YSQLConstant rightVal) { + if (rightVal.isNull()) { + return YSQLConstant.createNullConstant(); + } else if (rightVal.isString()) { + return isLessThan(rightVal.cast(YSQLDataType.BOOLEAN)); + } else { + assert rightVal.isBoolean(); + return YSQLConstant.createBooleanConstant((value ? 1 : 0) < (rightVal.asBoolean() ? 1 : 0)); + } + } + + @Override + public YSQLConstant cast(YSQLDataType type) { + switch (type) { + case BOOLEAN: + return this; + case INT: + return YSQLConstant.createIntConstant(value ? 1 : 0); + case TEXT: + return YSQLConstant.createTextConstant(value ? "true" : "false"); + default: + return null; + } + } + + @Override + public YSQLDataType getExpressionType() { + return YSQLDataType.BOOLEAN; + } + + } + + public static class YSQLNullConstant extends YSQLConstant { + + @Override + public String getTextRepresentation() { + return "NULL"; + } + + @Override + public String getUnquotedTextRepresentation() { + return getTextRepresentation(); + } + + @Override + public boolean isNull() { + return true; + } + + @Override + public YSQLConstant isEquals(YSQLConstant rightVal) { + return YSQLConstant.createNullConstant(); + } + + @Override + protected YSQLConstant isLessThan(YSQLConstant rightVal) { + return YSQLConstant.createNullConstant(); + } + + @Override + public YSQLConstant cast(YSQLDataType type) { + return YSQLConstant.createNullConstant(); + } + + @Override + public YSQLDataType getExpressionType() { + return null; + } + + } + + public static class StringConstant extends YSQLConstant { + + protected final String value; + + public StringConstant(String value) { + this.value = value; + } + + @Override + public String getTextRepresentation() { + return String.format("'%s'", value.replace("'", "''")); + } + + @Override + public String getUnquotedTextRepresentation() { + return value; + } + + @Override + public String asString() { + return value; + } + + @Override + public boolean isString() { + return true; + } + + @Override + public YSQLConstant isEquals(YSQLConstant rightVal) { + if (rightVal.isNull()) { + return YSQLConstant.createNullConstant(); + } else if (rightVal.isInt()) { + return cast(YSQLDataType.INT).isEquals(rightVal.cast(YSQLDataType.INT)); + } else if (rightVal.isBoolean()) { + return cast(YSQLDataType.BOOLEAN).isEquals(rightVal.cast(YSQLDataType.BOOLEAN)); + } else if (rightVal.isString()) { + return YSQLConstant.createBooleanConstant(value.contentEquals(rightVal.asString())); + } else { + throw new AssertionError(rightVal); + } + } + + @Override + protected YSQLConstant isLessThan(YSQLConstant rightVal) { + if (rightVal.isNull()) { + return YSQLConstant.createNullConstant(); + } else if (rightVal.isInt()) { + return cast(YSQLDataType.INT).isLessThan(rightVal.cast(YSQLDataType.INT)); + } else if (rightVal.isBoolean()) { + return cast(YSQLDataType.BOOLEAN).isLessThan(rightVal.cast(YSQLDataType.BOOLEAN)); + } else if (rightVal.isString()) { + return YSQLConstant.createBooleanConstant(value.compareTo(rightVal.asString()) < 0); + } else { + throw new AssertionError(rightVal); + } + } + + @Override + public YSQLConstant cast(YSQLDataType type) { + if (type == YSQLDataType.TEXT) { + return this; + } + String s = value.trim(); + switch (type) { + case BOOLEAN: + try { + return YSQLConstant.createBooleanConstant(Long.parseLong(s) != 0); + } catch (NumberFormatException e) { + } + switch (s.toUpperCase()) { + case "T": + case "TR": + case "TRU": + case "TRUE": + case "1": + case "YES": + case "YE": + case "Y": + case "ON": + return YSQLConstant.createTrue(); + case "F": + case "FA": + case "FAL": + case "FALS": + case "FALSE": + case "N": + case "NO": + case "OF": + case "OFF": + default: + return YSQLConstant.createFalse(); + } + case INT: + try { + return YSQLConstant.createIntConstant(Long.parseLong(s)); + } catch (NumberFormatException e) { + return YSQLConstant.createIntConstant(-1); + } + case TEXT: + return this; + default: + return null; + } + } + + @Override + public YSQLDataType getExpressionType() { + return YSQLDataType.TEXT; + } + + } + + public static class IntConstant extends YSQLConstant { + + private final long val; + + public IntConstant(long val) { + this.val = val; + } + + @Override + public String getTextRepresentation() { + return String.valueOf(val); + } + + @Override + public String getUnquotedTextRepresentation() { + return getTextRepresentation(); + } + + @Override + public long asInt() { + return val; + } + + @Override + public YSQLConstant isEquals(YSQLConstant rightVal) { + if (rightVal.isNull()) { + return YSQLConstant.createNullConstant(); + } else if (rightVal.isBoolean()) { + return cast(YSQLDataType.BOOLEAN).isEquals(rightVal); + } else if (rightVal.isInt()) { + return YSQLConstant.createBooleanConstant(val == rightVal.asInt()); + } else if (rightVal.isString()) { + return YSQLConstant.createBooleanConstant(val == rightVal.cast(YSQLDataType.INT).asInt()); + } else { + throw new AssertionError(rightVal); + } + } + + @Override + public boolean isInt() { + return true; + } + + @Override + protected YSQLConstant isLessThan(YSQLConstant rightVal) { + if (rightVal.isNull()) { + return YSQLConstant.createNullConstant(); + } else if (rightVal.isInt()) { + return YSQLConstant.createBooleanConstant(val < rightVal.asInt()); + } else if (rightVal.isBoolean()) { + throw new AssertionError(rightVal); + } else if (rightVal.isString()) { + return YSQLConstant.createBooleanConstant(val < rightVal.cast(YSQLDataType.INT).asInt()); + } else { + throw new IgnoreMeException(); + } + + } + + @Override + public YSQLConstant cast(YSQLDataType type) { + switch (type) { + case BOOLEAN: + return YSQLConstant.createBooleanConstant(val != 0); + case INT: + return this; + case TEXT: + return YSQLConstant.createTextConstant(String.valueOf(val)); + default: + return null; + } + } + + @Override + public YSQLDataType getExpressionType() { + return YSQLDataType.INT; + } + + } + + public static class ByteConstant extends StringConstant { + + public ByteConstant(String value) { + super(value); + } + + @Override + public String getTextRepresentation() { + return String.format("'%s'::bytea", value.replace("'", "''")); + } + } + + public abstract static class YSQLConstantBase extends YSQLConstant { + + @Override + public String getUnquotedTextRepresentation() { + return null; + } + + @Override + public YSQLConstant isEquals(YSQLConstant rightVal) { + return null; + } + + @Override + protected YSQLConstant isLessThan(YSQLConstant rightVal) { + return null; + } + + @Override + public YSQLConstant cast(YSQLDataType type) { + return null; + } + } + + public static class DecimalConstant extends YSQLConstantBase { + + private final BigDecimal val; + + public DecimalConstant(BigDecimal val) { + this.val = val; + } + + @Override + public String getTextRepresentation() { + return String.valueOf(val); + } + + @Override + public YSQLDataType getExpressionType() { + return YSQLDataType.DECIMAL; + } + + } + + public static class InetConstant extends YSQLConstantBase { + + private final String val; + + public InetConstant(String val) { + this.val = "'" + val + "'"; + } + + @Override + public String getTextRepresentation() { + return val; + } + + @Override + public YSQLDataType getExpressionType() { + return YSQLDataType.INET; + } + + } + + public static class FloatConstant extends YSQLConstantBase { + + private final float val; + + public FloatConstant(float val) { + this.val = val; + } + + @Override + public String getTextRepresentation() { + if (Double.isFinite(val)) { + return String.valueOf(val); + } else { + return "'" + val + "'"; + } + } + + @Override + public YSQLDataType getExpressionType() { + return YSQLDataType.FLOAT; + } + + } + + public static class DoubleConstant extends YSQLConstantBase { + + private final double val; + + public DoubleConstant(double val) { + this.val = val; + } + + @Override + public String getTextRepresentation() { + if (Double.isFinite(val)) { + return String.valueOf(val); + } else { + return "'" + val + "'"; + } + } + + @Override + public YSQLDataType getExpressionType() { + return YSQLDataType.FLOAT; + } + + } + + public static class BitConstant extends YSQLConstantBase { + + private final long val; + + public BitConstant(long val) { + this.val = val; + } + + @Override + public String getTextRepresentation() { + return String.format("B'%s'", Long.toBinaryString(val)); + } + + @Override + public YSQLDataType getExpressionType() { + return YSQLDataType.BIT; + } + + } + + public static class RangeConstant extends YSQLConstantBase { + + private final long left; + private final boolean leftIsInclusive; + private final long right; + private final boolean rightIsInclusive; + + public RangeConstant(long left, boolean leftIsInclusive, long right, boolean rightIsInclusive) { + this.left = left; + this.leftIsInclusive = leftIsInclusive; + this.right = right; + this.rightIsInclusive = rightIsInclusive; + } + + @Override + public String getTextRepresentation() { + StringBuilder sb = new StringBuilder(); + sb.append("'"); + if (leftIsInclusive) { + sb.append("["); + } else { + sb.append("("); + } + sb.append(left); + sb.append(","); + sb.append(right); + if (rightIsInclusive) { + sb.append("]"); + } else { + sb.append(")"); + } + sb.append("'"); + sb.append("::int4range"); + return sb.toString(); + } + + @Override + public YSQLDataType getExpressionType() { + return YSQLDataType.RANGE; + } + + } + +} diff --git a/src/sqlancer/yugabyte/ysql/ast/YSQLExpression.java b/src/sqlancer/yugabyte/ysql/ast/YSQLExpression.java new file mode 100644 index 000000000..5f68afd5d --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/ast/YSQLExpression.java @@ -0,0 +1,16 @@ +package sqlancer.yugabyte.ysql.ast; + +import sqlancer.common.ast.newast.Expression; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLColumn; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; + +public interface YSQLExpression extends Expression { + + default YSQLDataType getExpressionType() { + return null; + } + + default YSQLConstant getExpectedValue() { + return null; + } +} diff --git a/src/sqlancer/yugabyte/ysql/ast/YSQLFunction.java b/src/sqlancer/yugabyte/ysql/ast/YSQLFunction.java new file mode 100644 index 000000000..54349ebe9 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/ast/YSQLFunction.java @@ -0,0 +1,283 @@ +package sqlancer.yugabyte.ysql.ast; + +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; + +public class YSQLFunction implements YSQLExpression { + + private final String func; + private final YSQLExpression[] args; + private final YSQLDataType returnType; + private YSQLFunctionWithResult functionWithKnownResult; + + public YSQLFunction(YSQLFunctionWithResult func, YSQLDataType returnType, YSQLExpression... args) { + functionWithKnownResult = func; + this.func = func.getName(); + this.returnType = returnType; + this.args = args.clone(); + } + + public YSQLFunction(YSQLFunctionWithUnknownResult f, YSQLDataType returnType, YSQLExpression... args) { + this.func = f.getName(); + this.returnType = returnType; + this.args = args.clone(); + } + + public String getFunctionName() { + return func; + } + + public YSQLExpression[] getArguments() { + return args.clone(); + } + + @Override + public YSQLDataType getExpressionType() { + return returnType; + } + + @Override + public YSQLConstant getExpectedValue() { + if (functionWithKnownResult == null) { + return null; + } + YSQLConstant[] constants = new YSQLConstant[args.length]; + for (int i = 0; i < constants.length; i++) { + constants[i] = args[i].getExpectedValue(); + if (constants[i] == null) { + return null; + } + } + return functionWithKnownResult.apply(constants, args); + } + + public enum YSQLFunctionWithResult { + ABS(1, "abs") { + @Override + public YSQLConstant apply(YSQLConstant[] evaluatedArgs, YSQLExpression... args) { + if (evaluatedArgs[0].isNull()) { + return YSQLConstant.createNullConstant(); + } else { + return YSQLConstant.createIntConstant(Math.abs(evaluatedArgs[0].cast(YSQLDataType.INT).asInt())); + } + } + + @Override + public boolean supportsReturnType(YSQLDataType type) { + return type == YSQLDataType.INT; + } + + @Override + public YSQLDataType[] getInputTypesForReturnType(YSQLDataType returnType, int nrArguments) { + return new YSQLDataType[] { returnType }; + } + + }, + LOWER(1, "lower") { + @Override + public YSQLConstant apply(YSQLConstant[] evaluatedArgs, YSQLExpression... args) { + if (evaluatedArgs[0].isNull()) { + return YSQLConstant.createNullConstant(); + } else { + String text = evaluatedArgs[0].asString(); + return YSQLConstant.createTextConstant(text.toLowerCase()); + } + } + + @Override + public boolean supportsReturnType(YSQLDataType type) { + return type == YSQLDataType.TEXT; + } + + @Override + public YSQLDataType[] getInputTypesForReturnType(YSQLDataType returnType, int nrArguments) { + return new YSQLDataType[] { YSQLDataType.TEXT }; + } + + }, + LENGTH(1, "length") { + @Override + public YSQLConstant apply(YSQLConstant[] evaluatedArgs, YSQLExpression... args) { + if (evaluatedArgs[0].isNull()) { + return YSQLConstant.createNullConstant(); + } + String text = evaluatedArgs[0].asString(); + return YSQLConstant.createIntConstant(text.length()); + } + + @Override + public boolean supportsReturnType(YSQLDataType type) { + return type == YSQLDataType.INT; + } + + @Override + public YSQLDataType[] getInputTypesForReturnType(YSQLDataType returnType, int nrArguments) { + return new YSQLDataType[] { YSQLDataType.TEXT }; + } + }, + UPPER(1, "upper") { + @Override + public YSQLConstant apply(YSQLConstant[] evaluatedArgs, YSQLExpression... args) { + if (evaluatedArgs[0].isNull()) { + return YSQLConstant.createNullConstant(); + } else { + String text = evaluatedArgs[0].asString(); + return YSQLConstant.createTextConstant(text.toUpperCase()); + } + } + + @Override + public boolean supportsReturnType(YSQLDataType type) { + return type == YSQLDataType.TEXT; + } + + @Override + public YSQLDataType[] getInputTypesForReturnType(YSQLDataType returnType, int nrArguments) { + return new YSQLDataType[] { YSQLDataType.TEXT }; + } + + }, + // NULL_IF(2, "nullif") { + // + // @Override + // public YSQLConstant apply(YSQLConstant[] evaluatedArgs, YSQLExpression[] args) { + // YSQLConstant equals = evaluatedArgs[0].isEquals(evaluatedArgs[1]); + // if (equals.isBoolean() && equals.asBoolean()) { + // return YSQLConstant.createNullConstant(); + // } else { + // // TODO: SELECT (nullif('1', FALSE)); yields '1', but should yield TRUE + // return evaluatedArgs[0]; + // } + // } + // + // @Override + // public boolean supportsReturnType(YSQLDataType type) { + // return true; + // } + // + // @Override + // public YSQLDataType[] getInputTypesForReturnType(YSQLDataType returnType, int nrArguments) { + // return getType(nrArguments, returnType); + // } + // + // @Override + // public boolean checkArguments(YSQLExpression[] constants) { + // for (YSQLExpression e : constants) { + // if (!(e instanceof YSQLNullConstant)) { + // return true; + // } + // } + // return false; + // } + // + // }, + NUM_NONNULLS(1, "num_nonnulls") { + @Override + public YSQLConstant apply(YSQLConstant[] args, YSQLExpression... origArgs) { + int nr = 0; + for (YSQLConstant c : args) { + if (!c.isNull()) { + nr++; + } + } + return YSQLConstant.createIntConstant(nr); + } + + @Override + public YSQLDataType[] getInputTypesForReturnType(YSQLDataType returnType, int nrArguments) { + return getRandomTypes(nrArguments); + } + + @Override + public boolean supportsReturnType(YSQLDataType type) { + return type == YSQLDataType.INT; + } + + @Override + public boolean isVariadic() { + return true; + } + + }, + NUM_NULLS(1, "num_nulls") { + @Override + public YSQLConstant apply(YSQLConstant[] args, YSQLExpression... origArgs) { + int nr = 0; + for (YSQLConstant c : args) { + if (c.isNull()) { + nr++; + } + } + return YSQLConstant.createIntConstant(nr); + } + + @Override + public YSQLDataType[] getInputTypesForReturnType(YSQLDataType returnType, int nrArguments) { + return getRandomTypes(nrArguments); + } + + @Override + public boolean supportsReturnType(YSQLDataType type) { + return type == YSQLDataType.INT; + } + + @Override + public boolean isVariadic() { + return true; + } + + }; + + final int nrArgs; + private final String functionName; + private final boolean variadic; + + YSQLFunctionWithResult(int nrArgs, String functionName) { + this.nrArgs = nrArgs; + this.functionName = functionName; + this.variadic = false; + } + + public YSQLDataType[] getRandomTypes(int nr) { + YSQLDataType[] types = new YSQLDataType[nr]; + for (int i = 0; i < types.length; i++) { + types[i] = YSQLDataType.getRandomType(); + } + return types; + } + + /** + * Gets the number of arguments if the function is non-variadic. If the function is variadic, the minimum number + * of arguments is returned. + * + * @return the number of arguments + */ + public int getNrArgs() { + return nrArgs; + } + + public abstract YSQLConstant apply(YSQLConstant[] evaluatedArgs, YSQLExpression... args); + + @Override + public String toString() { + return functionName; + } + + public boolean isVariadic() { + return variadic; + } + + public String getName() { + return functionName; + } + + public abstract boolean supportsReturnType(YSQLDataType type); + + public abstract YSQLDataType[] getInputTypesForReturnType(YSQLDataType returnType, int nrArguments); + + public boolean checkArguments(YSQLExpression... constants) { + return true; + } + + } + +} diff --git a/src/sqlancer/yugabyte/ysql/ast/YSQLFunctionWithUnknownResult.java b/src/sqlancer/yugabyte/ysql/ast/YSQLFunctionWithUnknownResult.java new file mode 100644 index 000000000..8fbd1aa5d --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/ast/YSQLFunctionWithUnknownResult.java @@ -0,0 +1,174 @@ +package sqlancer.yugabyte.ysql.ast; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; +import sqlancer.yugabyte.ysql.gen.YSQLExpressionGenerator; + +public enum YSQLFunctionWithUnknownResult { + + ABBREV("abbrev", YSQLDataType.TEXT, YSQLDataType.INET), + BROADCAST("broadcast", YSQLDataType.INET, YSQLDataType.INET), FAMILY("family", YSQLDataType.INT, YSQLDataType.INET), + HOSTMASK("hostmask", YSQLDataType.INET, YSQLDataType.INET), MASKLEN("masklen", YSQLDataType.INT, YSQLDataType.INET), + NETMASK("netmask", YSQLDataType.INET, YSQLDataType.INET), + SET_MASKLEN("set_masklen", YSQLDataType.INET, YSQLDataType.INET, YSQLDataType.INT), + TEXT("text", YSQLDataType.TEXT, YSQLDataType.INET), + INET_SAME_FAMILY("inet_same_family", YSQLDataType.BOOLEAN, YSQLDataType.INET, YSQLDataType.INET), + + // https://www.postgres.org/docs/devel/functions-admin.html#FUNCTIONS-ADMIN-SIGNAL-TABLE + // PG_RELOAD_CONF("pg_reload_conf", YSQLDataType.BOOLEAN), // too much output + // PG_ROTATE_LOGFILE("pg_rotate_logfile", YSQLDataType.BOOLEAN), prints warning + + // https://www.postgresql.org/docs/devel/functions-info.html#FUNCTIONS-INFO-SESSION-TABLE + CURRENT_DATABASE("current_database", YSQLDataType.TEXT), // name + // CURRENT_QUERY("current_query", YSQLDataType.TEXT), // can generate false positives + CURRENT_SCHEMA("current_schema", YSQLDataType.TEXT), // name + // CURRENT_SCHEMAS("current_schemas", YSQLDataType.TEXT, YSQLDataType.BOOLEAN), + INET_CLIENT_PORT("inet_client_port", YSQLDataType.INT), INET_SERVER_PORT("inet_server_port", YSQLDataType.INT), + PG_BACKEND_PID("pg_backend_pid", YSQLDataType.INT), PG_CURRENT_LOGFILE("pg_current_logfile", YSQLDataType.TEXT), + // PG_IS_OTHER_TEMP_SCHEMA("pg_is_other_temp_schema", YSQLDataType.BOOLEAN), + // PG_JIT_AVAILABLE("pg_is_other_temp_schema", YSQLDataType.BOOLEAN), + PG_NOTIFICATION_QUEUE_USAGE("pg_notification_queue_usage", YSQLDataType.REAL), + PG_TRIGGER_DEPTH("pg_trigger_depth", YSQLDataType.INT), VERSION("version", YSQLDataType.TEXT), + + // + TO_CHAR("to_char", YSQLDataType.TEXT, YSQLDataType.BYTEA, YSQLDataType.TEXT) { + @Override + public YSQLExpression[] getArguments(YSQLDataType returnType, YSQLExpressionGenerator gen, int depth) { + YSQLExpression[] args = super.getArguments(returnType, gen, depth); + args[0] = gen.generateExpression(YSQLDataType.getRandomType()); + return args; + } + }, + + // String functions + ASCII("ascii", YSQLDataType.INT, YSQLDataType.TEXT), + BTRIM("btrim", YSQLDataType.TEXT, YSQLDataType.TEXT, YSQLDataType.TEXT), + CHR("chr", YSQLDataType.TEXT, YSQLDataType.INT), + CONVERT_FROM("convert_from", YSQLDataType.TEXT, YSQLDataType.TEXT, YSQLDataType.TEXT) { + @Override + public YSQLExpression[] getArguments(YSQLDataType returnType, YSQLExpressionGenerator gen, int depth) { + YSQLExpression[] args = super.getArguments(returnType, gen, depth); + args[1] = YSQLConstant.createTextConstant("UTF8"); + return args; + } + }, + // concat + // segfault + BIT_LENGTH("bit_length", YSQLDataType.INT, YSQLDataType.BYTEA), + INITCAP("initcap", YSQLDataType.TEXT, YSQLDataType.TEXT), + LEFT("left", YSQLDataType.TEXT, YSQLDataType.INT, YSQLDataType.TEXT), + LOWER("lower", YSQLDataType.TEXT, YSQLDataType.TEXT), MD5("md5", YSQLDataType.TEXT, YSQLDataType.TEXT), + UPPER("upper", YSQLDataType.TEXT, YSQLDataType.TEXT), + // PG_CLIENT_ENCODING("pg_client_encoding", YSQLDataType.TEXT), + QUOTE_LITERAL("quote_literal", YSQLDataType.TEXT, YSQLDataType.TEXT), + QUOTE_IDENT("quote_ident", YSQLDataType.TEXT, YSQLDataType.TEXT), + REGEX_REPLACE("regexp_replace", YSQLDataType.TEXT, YSQLDataType.TEXT, YSQLDataType.TEXT, YSQLDataType.TEXT), + // todo mute repeat function because it may provide OOMs + // REPEAT("repeat", YSQLDataType.TEXT, YSQLDataType.TEXT, YSQLDataType.INT), + REPLACE("replace", YSQLDataType.TEXT, YSQLDataType.TEXT, YSQLDataType.TEXT, YSQLDataType.TEXT), + REVERSE("reverse", YSQLDataType.TEXT, YSQLDataType.TEXT), + RIGHT("right", YSQLDataType.TEXT, YSQLDataType.TEXT, YSQLDataType.INT), + RPAD("rpad", YSQLDataType.TEXT, YSQLDataType.INT, YSQLDataType.TEXT), + RTRIM("rtrim", YSQLDataType.TEXT, YSQLDataType.TEXT), + SPLIT_PART("split_part", YSQLDataType.TEXT, YSQLDataType.TEXT, YSQLDataType.INT), + STRPOS("strpos", YSQLDataType.INT, YSQLDataType.TEXT, YSQLDataType.TEXT), + SUBSTR("substr", YSQLDataType.TEXT, YSQLDataType.TEXT, YSQLDataType.INT, YSQLDataType.INT), + TO_ASCII("to_ascii", YSQLDataType.TEXT, YSQLDataType.TEXT), TO_HEX("to_hex", YSQLDataType.INT, YSQLDataType.TEXT), + TRANSLATE("translate", YSQLDataType.TEXT, YSQLDataType.TEXT, YSQLDataType.TEXT, YSQLDataType.TEXT), + // mathematical functions + // https://www.postgresql.org/docs/9.5/functions-math.html + ABS("abs", YSQLDataType.REAL, YSQLDataType.REAL), CBRT("cbrt", YSQLDataType.REAL, YSQLDataType.REAL), + CEILING("ceiling", YSQLDataType.REAL), // + DEGREES("degrees", YSQLDataType.REAL), EXP("exp", YSQLDataType.REAL), LN("ln", YSQLDataType.REAL), + LOG("log", YSQLDataType.REAL), LOG2("log", YSQLDataType.REAL, YSQLDataType.REAL), PI("pi", YSQLDataType.REAL), + POWER("power", YSQLDataType.REAL, YSQLDataType.REAL), TRUNC("trunc", YSQLDataType.REAL, YSQLDataType.INT), + TRUNC2("trunc", YSQLDataType.REAL, YSQLDataType.INT, YSQLDataType.REAL), FLOOR("floor", YSQLDataType.REAL), + + // trigonometric functions - complete + // https://www.postgresql.org/docs/12/functions-math.html#FUNCTIONS-MATH-TRIG-TABLE + ACOS("acos", YSQLDataType.REAL), // + ACOSD("acosd", YSQLDataType.REAL), // + ASIN("asin", YSQLDataType.REAL), // + ASIND("asind", YSQLDataType.REAL), // + ATAN("atan", YSQLDataType.REAL), // + ATAND("atand", YSQLDataType.REAL), // + ATAN2("atan2", YSQLDataType.REAL, YSQLDataType.REAL), // + ATAN2D("atan2d", YSQLDataType.REAL, YSQLDataType.REAL), // + COS("cos", YSQLDataType.REAL), // + COSD("cosd", YSQLDataType.REAL), // + COT("cot", YSQLDataType.REAL), // + COTD("cotd", YSQLDataType.REAL), // + SIN("sin", YSQLDataType.REAL), // + SIND("sind", YSQLDataType.REAL), // + TAN("tan", YSQLDataType.REAL), // + TAND("tand", YSQLDataType.REAL), // + + // hyperbolic functions - complete + // https://www.postgresql.org/docs/12/functions-math.html#FUNCTIONS-MATH-HYP-TABLE + SINH("sinh", YSQLDataType.REAL), // + COSH("cosh", YSQLDataType.REAL), // + TANH("tanh", YSQLDataType.REAL), // + ASINH("asinh", YSQLDataType.REAL), // + ACOSH("acosh", YSQLDataType.REAL), // + ATANH("atanh", YSQLDataType.REAL), // + + // https://www.postgresql.org/docs/devel/functions-binarystring.html + GET_BIT("get_bit", YSQLDataType.INT, YSQLDataType.TEXT, YSQLDataType.INT), + GET_BYTE("get_byte", YSQLDataType.INT, YSQLDataType.TEXT, YSQLDataType.INT), + + // range functions + // https://www.postgresql.org/docs/devel/functions-range.html#RANGE-FUNCTIONS-TABLE + RANGE_LOWER("lower", YSQLDataType.INT, YSQLDataType.RANGE), // + RANGE_UPPER("upper", YSQLDataType.INT, YSQLDataType.RANGE), // + RANGE_ISEMPTY("isempty", YSQLDataType.BOOLEAN, YSQLDataType.RANGE), // + RANGE_LOWER_INC("lower_inc", YSQLDataType.BOOLEAN, YSQLDataType.RANGE), // + RANGE_UPPER_INC("upper_inc", YSQLDataType.BOOLEAN, YSQLDataType.RANGE), // + RANGE_LOWER_INF("lower_inf", YSQLDataType.BOOLEAN, YSQLDataType.RANGE), // + RANGE_UPPER_INF("upper_inf", YSQLDataType.BOOLEAN, YSQLDataType.RANGE), // + RANGE_MERGE("range_merge", YSQLDataType.RANGE, YSQLDataType.RANGE, YSQLDataType.RANGE), // + + // https://www.postgresql.org/docs/devel/functions-admin.html#FUNCTIONS-ADMIN-DBSIZE + GET_COLUMN_SIZE("get_column_size", YSQLDataType.INT, YSQLDataType.TEXT); + // PG_DATABASE_SIZE("pg_database_size", YSQLDataType.INT, YSQLDataType.INT); + // PG_SIZE_BYTES("pg_size_bytes", YSQLDataType.INT, YSQLDataType.TEXT); + + private final String functionName; + private final YSQLDataType returnType; + private final YSQLDataType[] argTypes; + + YSQLFunctionWithUnknownResult(String functionName, YSQLDataType returnType, YSQLDataType... indexType) { + this.functionName = functionName; + this.returnType = returnType; + this.argTypes = indexType.clone(); + } + + public static List getSupportedFunctions(YSQLDataType type) { + List functions = new ArrayList<>(); + for (YSQLFunctionWithUnknownResult func : values()) { + if (func.isCompatibleWithReturnType(type)) { + functions.add(func); + } + } + return functions; + } + + public boolean isCompatibleWithReturnType(YSQLDataType t) { + return t == returnType; + } + + public YSQLExpression[] getArguments(YSQLDataType returnType, YSQLExpressionGenerator gen, int depth) { + YSQLExpression[] args = new YSQLExpression[argTypes.length]; + for (int i = 0; i < args.length; i++) { + args[i] = gen.generateExpression(depth, argTypes[i]); + } + return args; + + } + + public String getName() { + return functionName; + } + +} diff --git a/src/sqlancer/yugabyte/ysql/ast/YSQLInOperation.java b/src/sqlancer/yugabyte/ysql/ast/YSQLInOperation.java new file mode 100644 index 000000000..a8c5b9490 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/ast/YSQLInOperation.java @@ -0,0 +1,65 @@ +package sqlancer.yugabyte.ysql.ast; + +import java.util.List; + +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; + +public class YSQLInOperation implements YSQLExpression { + + private final YSQLExpression expr; + private final List listElements; + private final boolean isTrue; + + public YSQLInOperation(YSQLExpression expr, List listElements, boolean isTrue) { + this.expr = expr; + this.listElements = listElements; + this.isTrue = isTrue; + } + + public YSQLExpression getExpr() { + return expr; + } + + public List getListElements() { + return listElements; + } + + public boolean isTrue() { + return isTrue; + } + + @Override + public YSQLDataType getExpressionType() { + return YSQLDataType.BOOLEAN; + } + + @Override + public YSQLConstant getExpectedValue() { + YSQLConstant leftValue = expr.getExpectedValue(); + if (leftValue == null) { + return null; + } + if (leftValue.isNull()) { + return YSQLConstant.createNullConstant(); + } + boolean isNull = false; + for (YSQLExpression expr : getListElements()) { + YSQLConstant rightExpectedValue = expr.getExpectedValue(); + if (rightExpectedValue == null) { + return null; + } + if (rightExpectedValue.isNull()) { + isNull = true; + } else if (rightExpectedValue.isEquals(this.expr.getExpectedValue()).isBoolean() + && rightExpectedValue.isEquals(this.expr.getExpectedValue()).asBoolean()) { + return YSQLConstant.createBooleanConstant(isTrue); + } + } + + if (isNull) { + return YSQLConstant.createNullConstant(); + } else { + return YSQLConstant.createBooleanConstant(!isTrue); + } + } +} diff --git a/src/sqlancer/yugabyte/ysql/ast/YSQLJoin.java b/src/sqlancer/yugabyte/ysql/ast/YSQLJoin.java new file mode 100644 index 000000000..4edcbd11a --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/ast/YSQLJoin.java @@ -0,0 +1,56 @@ +package sqlancer.yugabyte.ysql.ast; + +import sqlancer.Randomly; +import sqlancer.common.ast.newast.Join; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLColumn; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLTable; + +public class YSQLJoin implements YSQLExpression, Join { + + private final YSQLExpression tableReference; + private YSQLExpression onClause; + private final YSQLJoinType type; + + public YSQLJoin(YSQLExpression tableReference, YSQLExpression onClause, YSQLJoinType type) { + this.tableReference = tableReference; + this.onClause = onClause; + this.type = type; + } + + public YSQLExpression getTableReference() { + return tableReference; + } + + public YSQLExpression getOnClause() { + return onClause; + } + + public YSQLJoinType getType() { + return type; + } + + @Override + public YSQLDataType getExpressionType() { + throw new AssertionError(); + } + + @Override + public YSQLConstant getExpectedValue() { + throw new AssertionError(); + } + + public enum YSQLJoinType { + INNER, LEFT, RIGHT, FULL, CROSS; + + public static YSQLJoinType getRandom() { + return Randomly.fromOptions(values()); + } + + } + + @Override + public void setOnClause(YSQLExpression onClause) { + this.onClause = onClause; + } +} diff --git a/src/sqlancer/yugabyte/ysql/ast/YSQLOrderByTerm.java b/src/sqlancer/yugabyte/ysql/ast/YSQLOrderByTerm.java new file mode 100644 index 000000000..e57a347f9 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/ast/YSQLOrderByTerm.java @@ -0,0 +1,42 @@ +package sqlancer.yugabyte.ysql.ast; + +import sqlancer.Randomly; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; + +public class YSQLOrderByTerm implements YSQLExpression { + + private final YSQLOrder order; + private final YSQLExpression expr; + + public YSQLOrderByTerm(YSQLExpression expr, YSQLOrder order) { + this.expr = expr; + this.order = order; + } + + public YSQLOrder getOrder() { + return order; + } + + public YSQLExpression getExpr() { + return expr; + } + + @Override + public YSQLDataType getExpressionType() { + return null; + } + + @Override + public YSQLConstant getExpectedValue() { + throw new AssertionError(this); + } + + public enum YSQLOrder { + ASC, DESC; + + public static YSQLOrder getRandomOrder() { + return Randomly.fromOptions(YSQLOrder.values()); + } + } + +} diff --git a/src/sqlancer/yugabyte/ysql/ast/YSQLPOSIXRegularExpression.java b/src/sqlancer/yugabyte/ysql/ast/YSQLPOSIXRegularExpression.java new file mode 100644 index 000000000..975ac6aaf --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/ast/YSQLPOSIXRegularExpression.java @@ -0,0 +1,65 @@ +package sqlancer.yugabyte.ysql.ast; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; + +public class YSQLPOSIXRegularExpression implements YSQLExpression { + + private final YSQLExpression string; + private final YSQLExpression regex; + private final POSIXRegex op; + + public YSQLPOSIXRegularExpression(YSQLExpression string, YSQLExpression regex, POSIXRegex op) { + this.string = string; + this.regex = regex; + this.op = op; + } + + @Override + public YSQLDataType getExpressionType() { + return YSQLDataType.BOOLEAN; + } + + @Override + public YSQLConstant getExpectedValue() { + return null; + } + + public YSQLExpression getRegex() { + return regex; + } + + public YSQLExpression getString() { + return string; + } + + public POSIXRegex getOp() { + return op; + } + + public enum POSIXRegex implements Operator { + MATCH_CASE_SENSITIVE("~"), MATCH_CASE_INSENSITIVE("~*"), NOT_MATCH_CASE_SENSITIVE("!~"), + NOT_MATCH_CASE_INSENSITIVE("!~*"); + + private final String repr; + + POSIXRegex(String repr) { + this.repr = repr; + } + + public static POSIXRegex getRandom() { + return Randomly.fromOptions(values()); + } + + public String getStringRepresentation() { + return repr; + } + + @Override + public String getTextRepresentation() { + return toString(); + } + } + +} diff --git a/src/sqlancer/yugabyte/ysql/ast/YSQLPostfixOperation.java b/src/sqlancer/yugabyte/ysql/ast/YSQLPostfixOperation.java new file mode 100644 index 000000000..65ab056c5 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/ast/YSQLPostfixOperation.java @@ -0,0 +1,146 @@ +package sqlancer.yugabyte.ysql.ast; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; + +public class YSQLPostfixOperation implements YSQLExpression { + + private final YSQLExpression expr; + private final PostfixOperator op; + private final String operatorTextRepresentation; + + public YSQLPostfixOperation(YSQLExpression expr, PostfixOperator op) { + this.expr = expr; + this.operatorTextRepresentation = Randomly.fromOptions(op.textRepresentations); + this.op = op; + } + + public static YSQLExpression create(YSQLExpression expr, PostfixOperator op) { + return new YSQLPostfixOperation(expr, op); + } + + @Override + public YSQLDataType getExpressionType() { + return YSQLDataType.BOOLEAN; + } + + @Override + public YSQLConstant getExpectedValue() { + YSQLConstant expectedValue = expr.getExpectedValue(); + if (expectedValue == null) { + return null; + } + return op.apply(expectedValue); + } + + public String getOperatorTextRepresentation() { + return operatorTextRepresentation; + } + + public YSQLExpression getExpression() { + return expr; + } + + public enum PostfixOperator implements Operator { + IS_NULL("IS NULL", "ISNULL") { + @Override + public YSQLConstant apply(YSQLConstant expectedValue) { + return YSQLConstant.createBooleanConstant(expectedValue.isNull()); + } + + @Override + public YSQLDataType[] getInputDataTypes() { + return YSQLDataType.values(); + } + + }, + IS_UNKNOWN("IS UNKNOWN") { + @Override + public YSQLConstant apply(YSQLConstant expectedValue) { + return YSQLConstant.createBooleanConstant(expectedValue.isNull()); + } + + @Override + public YSQLDataType[] getInputDataTypes() { + return new YSQLDataType[] { YSQLDataType.BOOLEAN }; + } + }, + + IS_NOT_NULL("IS NOT NULL", "NOTNULL") { + @Override + public YSQLConstant apply(YSQLConstant expectedValue) { + return YSQLConstant.createBooleanConstant(!expectedValue.isNull()); + } + + @Override + public YSQLDataType[] getInputDataTypes() { + return YSQLDataType.values(); + } + + }, + IS_NOT_UNKNOWN("IS NOT UNKNOWN") { + @Override + public YSQLConstant apply(YSQLConstant expectedValue) { + return YSQLConstant.createBooleanConstant(!expectedValue.isNull()); + } + + @Override + public YSQLDataType[] getInputDataTypes() { + return new YSQLDataType[] { YSQLDataType.BOOLEAN }; + } + }, + IS_TRUE("IS TRUE") { + @Override + public YSQLConstant apply(YSQLConstant expectedValue) { + if (expectedValue.isNull()) { + return YSQLConstant.createFalse(); + } else { + return YSQLConstant.createBooleanConstant(expectedValue.cast(YSQLDataType.BOOLEAN).asBoolean()); + } + } + + @Override + public YSQLDataType[] getInputDataTypes() { + return new YSQLDataType[] { YSQLDataType.BOOLEAN }; + } + + }, + IS_FALSE("IS FALSE") { + @Override + public YSQLConstant apply(YSQLConstant expectedValue) { + if (expectedValue.isNull()) { + return YSQLConstant.createFalse(); + } else { + return YSQLConstant.createBooleanConstant(!expectedValue.cast(YSQLDataType.BOOLEAN).asBoolean()); + } + } + + @Override + public YSQLDataType[] getInputDataTypes() { + return new YSQLDataType[] { YSQLDataType.BOOLEAN }; + } + + }; + + private final String[] textRepresentations; + + PostfixOperator(String... textRepresentations) { + this.textRepresentations = textRepresentations.clone(); + } + + public static PostfixOperator getRandom() { + return Randomly.fromOptions(values()); + } + + public abstract YSQLConstant apply(YSQLConstant expectedValue); + + public abstract YSQLDataType[] getInputDataTypes(); + + @Override + public String getTextRepresentation() { + return toString(); + } + } + +} diff --git a/src/sqlancer/yugabyte/ysql/ast/YSQLPostfixText.java b/src/sqlancer/yugabyte/ysql/ast/YSQLPostfixText.java new file mode 100644 index 000000000..af9e64498 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/ast/YSQLPostfixText.java @@ -0,0 +1,36 @@ +package sqlancer.yugabyte.ysql.ast; + +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; + +public class YSQLPostfixText implements YSQLExpression { + + private final YSQLExpression expr; + private final String text; + private final YSQLConstant expectedValue; + private final YSQLDataType type; + + public YSQLPostfixText(YSQLExpression expr, String text, YSQLConstant expectedValue, YSQLDataType type) { + this.expr = expr; + this.text = text; + this.expectedValue = expectedValue; + this.type = type; + } + + public YSQLExpression getExpr() { + return expr; + } + + public String getText() { + return text; + } + + @Override + public YSQLDataType getExpressionType() { + return type; + } + + @Override + public YSQLConstant getExpectedValue() { + return expectedValue; + } +} diff --git a/src/sqlancer/yugabyte/ysql/ast/YSQLPrefixOperation.java b/src/sqlancer/yugabyte/ysql/ast/YSQLPrefixOperation.java new file mode 100644 index 000000000..d34f07567 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/ast/YSQLPrefixOperation.java @@ -0,0 +1,115 @@ +package sqlancer.yugabyte.ysql.ast; + +import sqlancer.IgnoreMeException; +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; + +public class YSQLPrefixOperation implements YSQLExpression { + + private final YSQLExpression expr; + private final PrefixOperator op; + + public YSQLPrefixOperation(YSQLExpression expr, PrefixOperator op) { + this.expr = expr; + this.op = op; + } + + @Override + public YSQLDataType getExpressionType() { + return op.getExpressionType(); + } + + @Override + public YSQLConstant getExpectedValue() { + YSQLConstant expectedValue = expr.getExpectedValue(); + if (expectedValue == null) { + return null; + } + return op.getExpectedValue(expectedValue); + } + + public YSQLDataType[] getInputDataTypes() { + return op.dataTypes; + } + + public String getTextRepresentation() { + return op.textRepresentation; + } + + public YSQLExpression getExpression() { + return expr; + } + + public enum PrefixOperator implements Operator { + NOT("NOT", YSQLDataType.BOOLEAN) { + @Override + public YSQLDataType getExpressionType() { + return YSQLDataType.BOOLEAN; + } + + @Override + protected YSQLConstant getExpectedValue(YSQLConstant expectedValue) { + if (expectedValue.isNull()) { + return YSQLConstant.createNullConstant(); + } else { + return YSQLConstant.createBooleanConstant(!expectedValue.cast(YSQLDataType.BOOLEAN).asBoolean()); + } + } + }, + UNARY_PLUS("+", YSQLDataType.INT) { + @Override + public YSQLDataType getExpressionType() { + return YSQLDataType.INT; + } + + @Override + protected YSQLConstant getExpectedValue(YSQLConstant expectedValue) { + // TODO: actual converts to double precision + return expectedValue; + } + + }, + UNARY_MINUS("-", YSQLDataType.INT) { + @Override + public YSQLDataType getExpressionType() { + return YSQLDataType.INT; + } + + @Override + protected YSQLConstant getExpectedValue(YSQLConstant expectedValue) { + if (expectedValue.isNull()) { + // TODO + throw new IgnoreMeException(); + } + if (expectedValue.isInt() && expectedValue.asInt() == Long.MIN_VALUE) { + throw new IgnoreMeException(); + } + try { + return YSQLConstant.createIntConstant(-expectedValue.asInt()); + } catch (UnsupportedOperationException e) { + return null; + } + } + + }; + + private final String textRepresentation; + private final YSQLDataType[] dataTypes; + + PrefixOperator(String textRepresentation, YSQLDataType... dataTypes) { + this.textRepresentation = textRepresentation; + this.dataTypes = dataTypes.clone(); + } + + public abstract YSQLDataType getExpressionType(); + + protected abstract YSQLConstant getExpectedValue(YSQLConstant expectedValue); + + @Override + public String getTextRepresentation() { + return toString(); + } + + } + +} diff --git a/src/sqlancer/yugabyte/ysql/ast/YSQLSelect.java b/src/sqlancer/yugabyte/ysql/ast/YSQLSelect.java new file mode 100644 index 000000000..fbf96ee45 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/ast/YSQLSelect.java @@ -0,0 +1,145 @@ +package sqlancer.yugabyte.ysql.ast; + +import java.util.Collections; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.ast.SelectBase; +import sqlancer.common.ast.newast.Select; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLColumn; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLTable; +import sqlancer.yugabyte.ysql.YSQLVisitor; + +public class YSQLSelect extends SelectBase + implements YSQLExpression, Select { + + private SelectType selectOption = SelectType.ALL; + private List joinClauses = Collections.emptyList(); + private YSQLExpression distinctOnClause; + private ForClause forClause; + + public void setSelectType(SelectType fromOptions) { + this.setSelectOption(fromOptions); + } + + public SelectType getSelectOption() { + return selectOption; + } + + public void setSelectOption(SelectType fromOptions) { + this.selectOption = fromOptions; + } + + @Override + public YSQLDataType getExpressionType() { + return null; + } + + @Override + public List getJoinClauses() { + return joinClauses; + } + + @Override + public void setJoinClauses(List joinStatements) { + this.joinClauses = joinStatements; + + } + + public YSQLExpression getDistinctOnClause() { + return distinctOnClause; + } + + public void setDistinctOnClause(YSQLExpression distinctOnClause) { + if (selectOption != SelectType.DISTINCT) { + throw new IllegalArgumentException(); + } + this.distinctOnClause = distinctOnClause; + } + + public ForClause getForClause() { + return forClause; + } + + public void setForClause(ForClause forClause) { + this.forClause = forClause; + } + + public enum ForClause { + UPDATE("UPDATE"), NO_KEY_UPDATE("NO KEY UPDATE"), SHARE("SHARE"), KEY_SHARE("KEY SHARE"); + + private final String textRepresentation; + + ForClause(String textRepresentation) { + this.textRepresentation = textRepresentation; + } + + public static ForClause getRandom() { + return Randomly.fromOptions(values()); + } + + public String getTextRepresentation() { + return textRepresentation; + } + } + + public enum SelectType { + DISTINCT, ALL; + + public static SelectType getRandom() { + return Randomly.fromOptions(values()); + } + } + + public static class YSQLFromTable implements YSQLExpression { + private final YSQLTable t; + private final boolean only; + + public YSQLFromTable(YSQLTable t, boolean only) { + this.t = t; + this.only = only; + } + + public YSQLTable getTable() { + return t; + } + + public boolean isOnly() { + return only; + } + + @Override + public YSQLDataType getExpressionType() { + return null; + } + } + + public static class YSQLSubquery implements YSQLExpression { + private final YSQLSelect s; + private final String name; + + public YSQLSubquery(YSQLSelect s, String name) { + this.s = s; + this.name = name; + } + + public YSQLSelect getSelect() { + return s; + } + + public String getName() { + return name; + } + + @Override + public YSQLDataType getExpressionType() { + return null; + } + } + + @Override + public String asString() { + return YSQLVisitor.asString(this); + } +} diff --git a/src/sqlancer/yugabyte/ysql/ast/YSQLSimilarTo.java b/src/sqlancer/yugabyte/ysql/ast/YSQLSimilarTo.java new file mode 100644 index 000000000..794525bc2 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/ast/YSQLSimilarTo.java @@ -0,0 +1,39 @@ +package sqlancer.yugabyte.ysql.ast; + +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; + +public class YSQLSimilarTo implements YSQLExpression { + + private final YSQLExpression string; + private final YSQLExpression similarTo; + private final YSQLExpression escapeCharacter; + + public YSQLSimilarTo(YSQLExpression string, YSQLExpression similarTo, YSQLExpression escapeCharacter) { + this.string = string; + this.similarTo = similarTo; + this.escapeCharacter = escapeCharacter; + } + + public YSQLExpression getString() { + return string; + } + + public YSQLExpression getSimilarTo() { + return similarTo; + } + + public YSQLExpression getEscapeCharacter() { + return escapeCharacter; + } + + @Override + public YSQLDataType getExpressionType() { + return YSQLDataType.BOOLEAN; + } + + @Override + public YSQLConstant getExpectedValue() { + return null; + } + +} diff --git a/src/sqlancer/yugabyte/ysql/gen/YSQLAlterTableGenerator.java b/src/sqlancer/yugabyte/ysql/gen/YSQLAlterTableGenerator.java new file mode 100644 index 000000000..c4a4effd8 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/gen/YSQLAlterTableGenerator.java @@ -0,0 +1,183 @@ +package sqlancer.yugabyte.ysql.gen; + +import java.util.List; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.yugabyte.ysql.YSQLErrors; +import sqlancer.yugabyte.ysql.YSQLGlobalState; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLTable; + +public class YSQLAlterTableGenerator { + + private final YSQLTable randomTable; + private final Randomly r; + private final YSQLGlobalState globalState; + + public YSQLAlterTableGenerator(YSQLTable randomTable, YSQLGlobalState globalState) { + this.randomTable = randomTable; + this.globalState = globalState; + this.r = globalState.getRandomly(); + } + + public static SQLQueryAdapter create(YSQLTable randomTable, YSQLGlobalState globalState) { + return new YSQLAlterTableGenerator(randomTable, globalState).generate(); + } + + public List getActions(ExpectedErrors errors) { + YSQLErrors.addCommonExpressionErrors(errors); + YSQLErrors.addCommonInsertUpdateErrors(errors); + YSQLErrors.addCommonTableErrors(errors); + errors.add("duplicate key value violates unique constraint"); + errors.add("cannot drop key column"); + errors.add("cannot drop desired object(s) because other objects depend on them"); + errors.add("invalid input syntax for"); + errors.add("cannot remove a key column"); + errors.add("it has pending trigger events"); + errors.add("could not open relation"); + errors.add("functions in index expression must be marked IMMUTABLE"); + errors.add("functions in index predicate must be marked IMMUTABLE"); + errors.add("has no default operator class for access method"); + errors.add("does not accept data type"); + errors.add("does not exist for access method"); + errors.add("could not find cast from"); + errors.add("does not exist"); // TODO: investigate + errors.add("constraints on permanent tables may reference only permanent tables"); + List action; + if (Randomly.getBoolean()) { + action = Randomly.nonEmptySubset(Action.values()); + } else { + // make it more likely that the ALTER TABLE succeeds + action = Randomly.subset(Randomly.smallNumber(), Action.values()); + } + if (randomTable.getColumns().size() == 1) { + action.remove(Action.ALTER_TABLE_DROP_COLUMN); + } + if (!randomTable.hasIndexes()) { + action.remove(Action.ADD_TABLE_CONSTRAINT_USING_INDEX); + } + if (action.isEmpty()) { + throw new IgnoreMeException(); + } + return action; + } + + public SQLQueryAdapter generate() { + ExpectedErrors errors = new ExpectedErrors(); + int i = 0; + List action = getActions(errors); + StringBuilder sb = new StringBuilder(); + sb.append("ALTER TABLE "); + if (Randomly.getBoolean()) { + sb.append(" ONLY"); + errors.add("cannot use ONLY for foreign key on partitioned table"); + } + sb.append(" "); + sb.append(randomTable.getName()); + sb.append(" "); + for (Action a : action) { + if (i++ != 0) { + sb.append(", "); + } + switch (a) { + case ALTER_TABLE_DROP_COLUMN: + sb.append("DROP "); + if (Randomly.getBoolean()) { + sb.append(" IF EXISTS "); + } + sb.append(randomTable.getRandomColumn().getName()); + errors.add("because other objects depend on it"); + if (Randomly.getBoolean()) { + sb.append(" "); + sb.append(Randomly.fromOptions("RESTRICT", "CASCADE")); + } + errors.add("does not exist"); + errors.add("cannot drop column"); + errors.add("cannot drop key column"); + errors.add("cannot drop inherited column"); + break; + case ADD_TABLE_CONSTRAINT: + sb.append("ADD "); + sb.append("CONSTRAINT ").append(r.getAlphabeticChar()).append(" "); + YSQLCommon.addTableConstraint(sb, randomTable, globalState, errors); + errors.add("already exists"); + errors.add("multiple primary keys for table"); + errors.add("could not create unique index"); + errors.add("contains null values"); + errors.add("cannot cast type"); + errors.add("unsupported PRIMARY KEY constraint with partition key definition"); + errors.add("unsupported UNIQUE constraint with partition key definition"); + errors.add("insufficient columns in UNIQUE constraint definition"); + errors.add("which is part of the partition key"); + errors.add("out of range"); + errors.add("there is no unique constraint matching given keys for referenced table"); + errors.add("constraints on temporary tables may reference only temporary tables"); + errors.add("constraints on unlogged tables may reference only permanent or unlogged tables"); + errors.add("constraints on permanent tables may reference only permanent tables"); + errors.add("cannot reference partitioned table"); + errors.add("cannot be implemented"); + errors.add("violates foreign key constraint"); + errors.add("unsupported ON COMMIT and foreign key combination"); + errors.add("USING INDEX is not supported on partitioned tables"); + if (Randomly.getBoolean()) { + sb.append(" NOT VALID"); + errors.add("cannot be marked NOT VALID"); + errors.add("cannot add NOT VALID foreign key on partitioned table"); + } else { + errors.add("is violated by some row"); + } + break; + case ADD_TABLE_CONSTRAINT_USING_INDEX: + sb.append("ADD "); + sb.append("CONSTRAINT ").append(r.getAlphabeticChar()).append(" "); + sb.append(Randomly.fromOptions("UNIQUE", "PRIMARY KEY")); + sb.append(" USING INDEX "); + sb.append(randomTable.getRandomIndex().getIndexName()); + errors.add("already exists"); + errors.add("PRIMARY KEY containing column of type"); + errors.add("not valid"); + errors.add("is not a unique index"); + errors.add("is already associated with a constraint"); + errors.add("Cannot create a primary key or unique constraint using such an index"); + errors.add("multiple primary keys for table"); + errors.add("appears twice in unique constraint"); + errors.add("appears twice in primary key constraint"); + errors.add("contains null values"); + errors.add("insufficient columns in PRIMARY KEY constraint definition"); + errors.add("which is part of the partition key"); + break; + case DISABLE_ROW_LEVEL_SECURITY: + sb.append("DISABLE ROW LEVEL SECURITY"); + break; + case ENABLE_ROW_LEVEL_SECURITY: + sb.append("ENABLE ROW LEVEL SECURITY"); + break; + case FORCE_ROW_LEVEL_SECURITY: + sb.append("FORCE ROW LEVEL SECURITY"); + break; + case NO_FORCE_ROW_LEVEL_SECURITY: + sb.append("NO FORCE ROW LEVEL SECURITY"); + break; + default: + throw new AssertionError(a); + } + } + + return new SQLQueryAdapter(sb.toString(), errors, true); + } + + protected enum Action { + // ALTER_TABLE_ADD_COLUMN, // [ COLUMN ] column data_type [ COLLATE collation ] [ + // column_constraint [ ... ] ] + ALTER_TABLE_DROP_COLUMN, // DROP [ COLUMN ] [ IF EXISTS ] column [ RESTRICT | CASCADE ] + ADD_TABLE_CONSTRAINT, // ADD table_constraint [ NOT VALID ] + ADD_TABLE_CONSTRAINT_USING_INDEX, // ADD table_constraint_using_index + DISABLE_ROW_LEVEL_SECURITY, // DISABLE ROW LEVEL SECURITY + ENABLE_ROW_LEVEL_SECURITY, // ENABLE ROW LEVEL SECURITY + FORCE_ROW_LEVEL_SECURITY, // FORCE ROW LEVEL SECURITY + NO_FORCE_ROW_LEVEL_SECURITY, // NO FORCE ROW LEVEL SECURITY + } + +} diff --git a/src/sqlancer/yugabyte/ysql/gen/YSQLAnalyzeGenerator.java b/src/sqlancer/yugabyte/ysql/gen/YSQLAnalyzeGenerator.java new file mode 100644 index 000000000..bee2ec5ef --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/gen/YSQLAnalyzeGenerator.java @@ -0,0 +1,37 @@ +package sqlancer.yugabyte.ysql.gen; + +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.yugabyte.ysql.YSQLGlobalState; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLTable; + +public final class YSQLAnalyzeGenerator { + + private YSQLAnalyzeGenerator() { + } + + public static SQLQueryAdapter create(YSQLGlobalState globalState) { + YSQLTable table = globalState.getSchema().getRandomTable(); + StringBuilder sb = new StringBuilder("ANALYZE"); + if (Randomly.getBoolean()) { + sb.append("("); + sb.append(" VERBOSE"); + sb.append(")"); + } + sb.append(" "); + sb.append(table.getName()); + if (Randomly.getBoolean()) { + sb.append("("); + sb.append(table.getRandomNonEmptyColumnSubset().stream().map(AbstractTableColumn::getName) + .collect(Collectors.joining(", "))); + sb.append(")"); + } + + return new SQLQueryAdapter(sb.toString(), ExpectedErrors.from("deadlock")); + } + +} diff --git a/src/sqlancer/yugabyte/ysql/gen/YSQLClusterGenerator.java b/src/sqlancer/yugabyte/ysql/gen/YSQLClusterGenerator.java new file mode 100644 index 000000000..6e852d167 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/gen/YSQLClusterGenerator.java @@ -0,0 +1,32 @@ +package sqlancer.yugabyte.ysql.gen; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.yugabyte.ysql.YSQLGlobalState; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLTable; + +public final class YSQLClusterGenerator { + + private YSQLClusterGenerator() { + } + + public static SQLQueryAdapter create(YSQLGlobalState globalState) { + ExpectedErrors errors = new ExpectedErrors(); + errors.add("there is no previously clustered index for table"); + errors.add("cannot cluster a partitioned table"); + errors.add("access method does not support clustering"); + StringBuilder sb = new StringBuilder("CLUSTER "); + if (Randomly.getBoolean()) { + YSQLTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + sb.append(table.getName()); + if (Randomly.getBoolean() && !table.getIndexes().isEmpty()) { + sb.append(" USING "); + sb.append(table.getRandomIndex().getIndexName()); + errors.add("cannot cluster on partial index"); + } + } + return new SQLQueryAdapter(sb.toString(), errors); + } + +} diff --git a/src/sqlancer/yugabyte/ysql/gen/YSQLCommentGenerator.java b/src/sqlancer/yugabyte/ysql/gen/YSQLCommentGenerator.java new file mode 100644 index 000000000..f020a8994 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/gen/YSQLCommentGenerator.java @@ -0,0 +1,68 @@ +package sqlancer.yugabyte.ysql.gen; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.yugabyte.ysql.YSQLGlobalState; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLTable; + +/** + * @see COMMENT + */ +public final class YSQLCommentGenerator { + + private YSQLCommentGenerator() { + } + + public static SQLQueryAdapter generate(YSQLGlobalState globalState) { + StringBuilder sb = new StringBuilder(); + sb.append("COMMENT ON "); + Action type = Randomly.fromOptions(Action.values()); + YSQLTable randomTable = globalState.getSchema().getRandomTable(); + switch (type) { + case INDEX: + sb.append("INDEX "); + if (randomTable.getIndexes().isEmpty()) { + throw new IgnoreMeException(); + } else { + sb.append(randomTable.getRandomIndex().getIndexName()); + } + break; + case COLUMN: + sb.append("COLUMN "); + sb.append(randomTable.getRandomColumn().getFullQualifiedName()); + break; + case STATISTICS: + sb.append("STATISTICS "); + if (randomTable.getStatistics().isEmpty()) { + throw new IgnoreMeException(); + } else { + sb.append(randomTable.getStatistics().get(0).getName()); + } + break; + case TABLE: + sb.append("TABLE "); + if (randomTable.isView()) { + throw new IgnoreMeException(); + } + sb.append(randomTable.getName()); + break; + default: + throw new AssertionError(type); + } + sb.append(" IS "); + if (Randomly.getBoolean()) { + sb.append("NULL"); + } else { + sb.append("'"); + sb.append(globalState.getRandomly().getString().replace("'", "''")); + sb.append("'"); + } + return new SQLQueryAdapter(sb.toString()); + } + + private enum Action { + INDEX, COLUMN, STATISTICS, TABLE + } + +} diff --git a/src/sqlancer/yugabyte/ysql/gen/YSQLCommon.java b/src/sqlancer/yugabyte/ysql/gen/YSQLCommon.java new file mode 100644 index 000000000..231bf8008 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/gen/YSQLCommon.java @@ -0,0 +1,283 @@ +package sqlancer.yugabyte.ysql.gen; + +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Collectors; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.yugabyte.ysql.YSQLErrors; +import sqlancer.yugabyte.ysql.YSQLGlobalState; +import sqlancer.yugabyte.ysql.YSQLProvider; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLColumn; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLTable; +import sqlancer.yugabyte.ysql.YSQLVisitor; +import sqlancer.yugabyte.ysql.ast.YSQLConstant; + +public final class YSQLCommon { + + private YSQLCommon() { + } + + public static boolean appendDataType(YSQLDataType type, StringBuilder sb, boolean allowSerial, + boolean generateOnlyKnown, List opClasses) throws AssertionError { + boolean serial = false; + switch (type) { + case BOOLEAN: + sb.append("boolean"); + break; + case INT: + if (Randomly.getBoolean() && allowSerial) { + serial = true; + sb.append(Randomly.fromOptions("serial", "bigserial")); + } else { + sb.append(Randomly.fromOptions("smallint", "integer", "bigint")); + } + break; + case TEXT: + if (Randomly.getBoolean()) { + sb.append("TEXT"); + } else if (Randomly.getBoolean()) { + // TODO: support CHAR (without VAR) + if (YSQLProvider.generateOnlyKnown || Randomly.getBoolean()) { + sb.append("VAR"); + } + sb.append("CHAR"); + sb.append("("); + sb.append(ThreadLocalRandom.current().nextInt(1, 500)); + sb.append(")"); + } else { + sb.append("name"); + } + break; + case DECIMAL: + sb.append("DECIMAL"); + break; + case FLOAT: + case REAL: + if (Randomly.getBoolean()) { + sb.append("REAL"); + } else { + sb.append("FLOAT"); + } + break; + case RANGE: + sb.append(Randomly.fromOptions("int4range", "int4range")); // , "int8range", "numrange" + break; + case MONEY: + sb.append("money"); + break; + case BYTEA: + sb.append("bytea"); + break; + case BIT: + sb.append("BIT"); + // if (Randomly.getBoolean()) { + sb.append(" VARYING"); + // } + sb.append("("); + sb.append(Randomly.getNotCachedInteger(1, 500)); + sb.append(")"); + break; + case INET: + sb.append("inet"); + break; + default: + throw new AssertionError(type); + } + return serial; + } + + public static void generateWith(StringBuilder sb, YSQLGlobalState globalState, ExpectedErrors errors, + List columnsToBeAdded, boolean isTemporaryTable) { + if (Randomly.getBoolean()) { + sb.append(" WITHOUT OIDS "); + } else if (Randomly.getBoolean() && !isTemporaryTable) { + if (Randomly.getBoolean()) { + sb.append(" SPLIT INTO "); + sb.append(Randomly.smallNumber() + 1); + sb.append(" TABLETS "); + + errors.add("cannot create colocated table with split option"); + errors.add("columns must be present to split by number of tablets"); + errors.add("option is not yet supported for hash partitioned tables"); + } else { + sb.append(" SPLIT AT VALUES ("); + + errors.add("cannot create colocated table with split option"); + errors.add("SPLIT AT option is not yet supported for hash partitioned tables"); + errors.add("Cannot have duplicate split rows"); // just in case + + boolean hasBoolean = false; + for (YSQLColumn column : columnsToBeAdded) { + if (column.getType().equals(YSQLDataType.BOOLEAN)) { + hasBoolean = true; + break; + } + } + + int splits = hasBoolean ? 2 : Randomly.smallNumber() + 2; + long start = Randomly.smallNumber(); + boolean[] bools = { false, true }; + for (int i = 1; i <= splits; i++) { + int size = columnsToBeAdded.size(); + int counter = 1; + for (YSQLColumn c : columnsToBeAdded) { + sb.append("("); + switch (c.getType()) { + case INT: + case REAL: + sb.append(YSQLConstant.createDoubleConstant(start)); + case FLOAT: + sb.append(YSQLConstant.createIntConstant(start)); + break; + case BOOLEAN: + sb.append(YSQLConstant.createBooleanConstant(bools[i - 1])); + break; + case TEXT: + sb.append(YSQLConstant.createTextConstant(String.valueOf(start))); + break; + default: + throw new IgnoreMeException(); + } + sb.append(")"); + counter++; + start += Randomly.smallNumber() + 1; + if (counter <= size) { + sb.append(","); + } + } + + if (i < splits) { + sb.append(","); + } + } + sb.append(")"); + } + } else if (Randomly.getBoolean()) { + errors.add("Cannot use TABLEGROUP with TEMP table"); + if (!globalState.getSchema().getDatabaseIsColocated(globalState.getConnection())) { + sb.append(" TABLEGROUP tg").append( + Randomly.getNotCachedInteger(1, (int) YSQLTableGroupGenerator.UNIQUE_TABLEGROUP_COUNTER.get())); + } + } + } + + public static void addTableConstraints(boolean excludePrimaryKey, StringBuilder sb, YSQLTable table, + YSQLGlobalState globalState, ExpectedErrors errors) { + // TODO constraint name + List tableConstraints = Randomly.nonEmptySubset(TableConstraints.values()); + if (excludePrimaryKey) { + tableConstraints.remove(TableConstraints.PRIMARY_KEY); + } + if (globalState.getSchema().getDatabaseTables().isEmpty()) { + tableConstraints.remove(TableConstraints.FOREIGN_KEY); + } + for (TableConstraints t : tableConstraints) { + sb.append(", "); + // TODO add index parameters + addTableConstraint(sb, table, globalState, t, errors); + } + } + + public static void addTableConstraint(StringBuilder sb, YSQLTable table, YSQLGlobalState globalState, + ExpectedErrors errors) { + addTableConstraint(sb, table, globalState, Randomly.fromOptions(TableConstraints.values()), errors); + } + + private static void addTableConstraint(StringBuilder sb, YSQLTable table, YSQLGlobalState globalState, + TableConstraints t, ExpectedErrors errors) { + List randomNonEmptyColumnSubset = table.getRandomNonEmptyColumnSubset(); + List otherColumns; + YSQLErrors.addCommonExpressionErrors(errors); + switch (t) { + case CHECK: + sb.append("CHECK("); + sb.append(YSQLVisitor.getExpressionAsString(globalState, YSQLDataType.BOOLEAN, table.getColumns())); + sb.append(")"); + errors.add("constraint must be added to child tables too"); + errors.add("missing FROM-clause entry for table"); + break; + case UNIQUE: + sb.append("UNIQUE("); + sb.append(randomNonEmptyColumnSubset.stream().map(AbstractTableColumn::getName) + .collect(Collectors.joining(", "))); + sb.append(")"); + break; + case PRIMARY_KEY: + sb.append("PRIMARY KEY("); + sb.append(randomNonEmptyColumnSubset.stream().map(AbstractTableColumn::getName) + .collect(Collectors.joining(", "))); + sb.append(")"); + break; + case FOREIGN_KEY: + sb.append("FOREIGN KEY ("); + sb.append(randomNonEmptyColumnSubset.stream().map(AbstractTableColumn::getName) + .collect(Collectors.joining(", "))); + sb.append(") REFERENCES "); + YSQLTable randomOtherTable = globalState.getSchema().getRandomTable(tab -> !tab.isView()); + sb.append(randomOtherTable.getName()); + if (randomOtherTable.getColumns().size() < randomNonEmptyColumnSubset.size()) { + throw new IgnoreMeException(); + } + otherColumns = randomOtherTable.getRandomNonEmptyColumnSubset(randomNonEmptyColumnSubset.size()); + sb.append("("); + sb.append(otherColumns.stream().map(AbstractTableColumn::getName).collect(Collectors.joining(", "))); + sb.append(")"); + if (Randomly.getBoolean()) { + sb.append(" "); + sb.append(Randomly.fromOptions("MATCH FULL", "MATCH SIMPLE")); + } + if (Randomly.getBoolean()) { + sb.append(" ON DELETE "); + errors.add("ERROR: invalid ON DELETE action for foreign key constraint containing generated column"); + deleteOrUpdateAction(sb); + } + if (Randomly.getBoolean()) { + sb.append(" ON UPDATE "); + errors.add("invalid ON UPDATE action for foreign key constraint containing generated column"); + deleteOrUpdateAction(sb); + } + if (Randomly.getBoolean()) { + sb.append(" "); + if (Randomly.getBoolean()) { + sb.append("DEFERRABLE"); + if (Randomly.getBoolean()) { + sb.append(" "); + sb.append(Randomly.fromOptions("INITIALLY DEFERRED", "INITIALLY IMMEDIATE")); + } + } else { + sb.append("NOT DEFERRABLE"); + } + } + break; + default: + throw new AssertionError(t); + } + } + + private static void deleteOrUpdateAction(StringBuilder sb) { + sb.append(Randomly.fromOptions("NO ACTION", "RESTRICT", "CASCADE", "SET NULL", "SET DEFAULT")); + } + + public enum TableConstraints { + CHECK, UNIQUE, PRIMARY_KEY, FOREIGN_KEY + } + + // private enum StorageParameters { + // COLOCATED("COLOCATED", (r) -> Randomly.getBoolean()); + // // TODO + // + // private final String parameter; + // private final Function op; + // + // StorageParameters(String parameter, Function op) { + // this.parameter = parameter; + // this.op = op; + // } + // } + +} diff --git a/src/sqlancer/yugabyte/ysql/gen/YSQLDeleteGenerator.java b/src/sqlancer/yugabyte/ysql/gen/YSQLDeleteGenerator.java new file mode 100644 index 000000000..e0128707b --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/gen/YSQLDeleteGenerator.java @@ -0,0 +1,47 @@ +package sqlancer.yugabyte.ysql.gen; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.yugabyte.ysql.YSQLErrors; +import sqlancer.yugabyte.ysql.YSQLGlobalState; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLTable; +import sqlancer.yugabyte.ysql.YSQLVisitor; + +public final class YSQLDeleteGenerator { + + private YSQLDeleteGenerator() { + } + + public static SQLQueryAdapter create(YSQLGlobalState globalState) { + YSQLTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + ExpectedErrors errors = new ExpectedErrors(); + errors.add("violates foreign key constraint"); + errors.add("violates not-null constraint"); + errors.add("could not determine which collation to use for string comparison"); + StringBuilder sb = new StringBuilder("DELETE FROM"); + if (Randomly.getBoolean()) { + sb.append(" ONLY"); + } + sb.append(" "); + sb.append(table.getName()); + if (Randomly.getBoolean()) { + sb.append(" WHERE "); + sb.append(YSQLVisitor.asString( + YSQLExpressionGenerator.generateExpression(globalState, table.getColumns(), YSQLDataType.BOOLEAN))); + } + if (Randomly.getBoolean()) { + sb.append(" RETURNING "); + sb.append( + YSQLVisitor.asString(YSQLExpressionGenerator.generateExpression(globalState, table.getColumns()))); + } + YSQLErrors.addCommonExpressionErrors(errors); + errors.add("out of range"); + errors.add("cannot cast"); + errors.add("invalid input syntax for"); + errors.add("division by zero"); + return new SQLQueryAdapter(sb.toString(), errors); + } + +} diff --git a/src/sqlancer/yugabyte/ysql/gen/YSQLDiscardGenerator.java b/src/sqlancer/yugabyte/ysql/gen/YSQLDiscardGenerator.java new file mode 100644 index 000000000..156fb97da --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/gen/YSQLDiscardGenerator.java @@ -0,0 +1,40 @@ +package sqlancer.yugabyte.ysql.gen; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.yugabyte.ysql.YSQLGlobalState; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLTable.TableType; + +public final class YSQLDiscardGenerator { + + private YSQLDiscardGenerator() { + } + + public static SQLQueryAdapter create(YSQLGlobalState globalState) { + StringBuilder sb = new StringBuilder(); + sb.append("DISCARD "); + // prevent that DISCARD discards all tables (if they are TEMP tables) + boolean hasNonTempTables = globalState.getSchema().getDatabaseTables().stream() + .anyMatch(t -> t.getTableType() == TableType.STANDARD); + String what; + if (hasNonTempTables) { + what = Randomly.fromOptions("ALL", "PLANS", "SEQUENCES", "TEMPORARY", "TEMP"); + } else { + what = Randomly.fromOptions("PLANS", "SEQUENCES"); + } + sb.append(what); + return new SQLQueryAdapter(sb.toString(), ExpectedErrors.from("cannot run inside a transaction block")) { + private static final long serialVersionUID = 1L; + + @Override + public boolean couldAffectSchema() { + return canDiscardTemporaryTables(what); + } + }; + } + + private static boolean canDiscardTemporaryTables(String what) { + return what.contentEquals("TEMPORARY") || what.contentEquals("TEMP") || what.contentEquals("ALL"); + } +} diff --git a/src/sqlancer/yugabyte/ysql/gen/YSQLDropIndexGenerator.java b/src/sqlancer/yugabyte/ysql/gen/YSQLDropIndexGenerator.java new file mode 100644 index 000000000..da9b5ae61 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/gen/YSQLDropIndexGenerator.java @@ -0,0 +1,41 @@ +package sqlancer.yugabyte.ysql.gen; + +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.DBMSCommon; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.yugabyte.ysql.YSQLGlobalState; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLIndex; + +public final class YSQLDropIndexGenerator { + + private YSQLDropIndexGenerator() { + } + + public static SQLQueryAdapter create(YSQLGlobalState globalState) { + List indexes = globalState.getSchema().getRandomTable().getIndexes(); + StringBuilder sb = new StringBuilder(); + sb.append("DROP INDEX "); + if (Randomly.getBoolean() || indexes.isEmpty()) { + sb.append("IF EXISTS "); + if (indexes.isEmpty() || Randomly.getBoolean()) { + sb.append(DBMSCommon.createIndexName(Randomly.smallNumber())); + } else { + sb.append(Randomly.fromList(indexes).getIndexName()); + } + } else { + sb.append(Randomly.fromList(indexes).getIndexName()); + } + if (Randomly.getBoolean()) { + sb.append(" "); + sb.append(Randomly.fromOptions("CASCADE", "RESTRICT")); + } + return new SQLQueryAdapter(sb.toString(), + ExpectedErrors.from("cannot drop desired object(s) because other objects depend on them", + "cannot drop index", "does not exist"), + true); + } + +} diff --git a/src/sqlancer/yugabyte/ysql/gen/YSQLExpressionGenerator.java b/src/sqlancer/yugabyte/ysql/gen/YSQLExpressionGenerator.java new file mode 100644 index 000000000..80c35c431 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/gen/YSQLExpressionGenerator.java @@ -0,0 +1,689 @@ +package sqlancer.yugabyte.ysql.gen; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.common.gen.ExpressionGenerator; +import sqlancer.common.gen.NoRECGenerator; +import sqlancer.common.gen.TLPWhereGenerator; +import sqlancer.common.schema.AbstractTables; +import sqlancer.yugabyte.ysql.YSQLCompoundDataType; +import sqlancer.yugabyte.ysql.YSQLGlobalState; +import sqlancer.yugabyte.ysql.YSQLProvider; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLColumn; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLRowValue; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLTable; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLTables; +import sqlancer.yugabyte.ysql.ast.YSQLAggregate; +import sqlancer.yugabyte.ysql.ast.YSQLAggregate.YSQLAggregateFunction; +import sqlancer.yugabyte.ysql.ast.YSQLBetweenOperation; +import sqlancer.yugabyte.ysql.ast.YSQLBinaryArithmeticOperation; +import sqlancer.yugabyte.ysql.ast.YSQLBinaryBitOperation; +import sqlancer.yugabyte.ysql.ast.YSQLBinaryComparisonOperation; +import sqlancer.yugabyte.ysql.ast.YSQLBinaryLogicalOperation; +import sqlancer.yugabyte.ysql.ast.YSQLBinaryRangeOperation; +import sqlancer.yugabyte.ysql.ast.YSQLCastOperation; +import sqlancer.yugabyte.ysql.ast.YSQLColumnValue; +import sqlancer.yugabyte.ysql.ast.YSQLConcatOperation; +import sqlancer.yugabyte.ysql.ast.YSQLConstant; +import sqlancer.yugabyte.ysql.ast.YSQLExpression; +import sqlancer.yugabyte.ysql.ast.YSQLFunction; +import sqlancer.yugabyte.ysql.ast.YSQLFunctionWithUnknownResult; +import sqlancer.yugabyte.ysql.ast.YSQLInOperation; +import sqlancer.yugabyte.ysql.ast.YSQLJoin; +import sqlancer.yugabyte.ysql.ast.YSQLOrderByTerm; +import sqlancer.yugabyte.ysql.ast.YSQLPOSIXRegularExpression; +import sqlancer.yugabyte.ysql.ast.YSQLPostfixOperation; +import sqlancer.yugabyte.ysql.ast.YSQLPostfixText; +import sqlancer.yugabyte.ysql.ast.YSQLPrefixOperation; +import sqlancer.yugabyte.ysql.ast.YSQLSelect; +import sqlancer.yugabyte.ysql.ast.YSQLSimilarTo; + +public class YSQLExpressionGenerator implements ExpressionGenerator, + NoRECGenerator, + TLPWhereGenerator { + + private final int maxDepth; + + private final Randomly r; + private final Map functionsAndTypes; + private final List allowedFunctionTypes; + private List columns; + private List tables; + private YSQLRowValue rw; + private boolean expectedResult; + private YSQLGlobalState globalState; + private boolean allowAggregateFunctions; + + public YSQLExpressionGenerator(YSQLGlobalState globalState) { + this.r = globalState.getRandomly(); + this.maxDepth = globalState.getOptions().getMaxExpressionDepth(); + this.globalState = globalState; + this.functionsAndTypes = globalState.getFunctionsAndTypes(); + this.allowedFunctionTypes = globalState.getAllowedFunctionTypes(); + } + + public static YSQLExpression generateExpression(YSQLGlobalState globalState, YSQLDataType type) { + return new YSQLExpressionGenerator(globalState).generateExpression(0, type); + } + + private static YSQLCompoundDataType getCompoundDataType(YSQLDataType type) { + switch (type) { + case BOOLEAN: + case DECIMAL: // TODO + case FLOAT: + case INT: + case MONEY: + case RANGE: + case REAL: + case INET: + case BYTEA: + return YSQLCompoundDataType.create(type); + case TEXT: // TODO + case BIT: + if (Randomly.getBoolean() + || YSQLProvider.generateOnlyKnown /* + * The PQS implementation does not check for size specifications + */) { + return YSQLCompoundDataType.create(type); + } else { + return YSQLCompoundDataType.create(type, (int) Randomly.getNotCachedInteger(1, 1000)); + } + default: + throw new AssertionError(type); + } + + } + + public static YSQLExpression generateConstant(Randomly r, YSQLDataType type) { + if (Randomly.getBooleanWithSmallProbability()) { + return YSQLConstant.createNullConstant(); + } + // if (Randomly.getBooleanWithSmallProbability()) { + // return YSQLConstant.createTextConstant(r.getString()); + // } + switch (type) { + case INT: + if (Randomly.getBooleanWithSmallProbability()) { + return YSQLConstant.createTextConstant(String.valueOf(r.getInteger())); + } else { + return YSQLConstant.createIntConstant(r.getInteger()); + } + case BOOLEAN: + if (Randomly.getBooleanWithSmallProbability() && !YSQLProvider.generateOnlyKnown) { + return YSQLConstant + .createTextConstant(Randomly.fromOptions("TR", "TRUE", "FA", "FALSE", "0", "1", "ON", "off")); + } else { + return YSQLConstant.createBooleanConstant(Randomly.getBoolean()); + } + case TEXT: + return YSQLConstant.createTextConstant(r.getString()); + case DECIMAL: + return YSQLConstant.createDecimalConstant(r.getRandomBigDecimal()); + case FLOAT: + return YSQLConstant.createFloatConstant((float) r.getDouble()); + case REAL: + return YSQLConstant.createDoubleConstant(r.getDouble()); + case RANGE: + return YSQLConstant.createRange(r.getInteger(), Randomly.getBoolean(), r.getInteger(), + Randomly.getBoolean()); + case MONEY: + return new YSQLCastOperation(generateConstant(r, YSQLDataType.FLOAT), + getCompoundDataType(YSQLDataType.MONEY)); + case INET: + return YSQLConstant.createInetConstant(getRandomInet(r)); + case BIT: + return YSQLConstant.createBitConstant(r.getInteger()); + case BYTEA: + return YSQLConstant.createByteConstant(String.valueOf(r.getInteger())); + default: + throw new AssertionError(type); + } + } + + private static String getRandomInet(Randomly r) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 4; i++) { + if (i != 0) { + sb.append('.'); + } + sb.append(r.getInteger() & 255); + } + return sb.toString(); + } + + public static YSQLExpression generateExpression(YSQLGlobalState globalState, List columns, + YSQLDataType type) { + return new YSQLExpressionGenerator(globalState).setColumns(columns).generateExpression(0, type); + } + + public static YSQLExpression generateExpression(YSQLGlobalState globalState, List columns) { + return new YSQLExpressionGenerator(globalState).setColumns(columns).generateExpression(0); + + } + + public YSQLExpressionGenerator setColumns(List columns) { + this.columns = columns; + return this; + } + + public YSQLExpressionGenerator setRowValue(YSQLRowValue rw) { + this.rw = rw; + return this; + } + + public YSQLExpression generateExpression(int depth) { + return generateExpression(depth, YSQLDataType.getRandomType()); + } + + @Override + public List generateOrderBys() { + List orderBys = new ArrayList<>(); + for (int i = 0; i < Randomly.smallNumber(); i++) { + orderBys.add(new YSQLOrderByTerm(YSQLColumnValue.create(Randomly.fromList(columns), null), + YSQLOrderByTerm.YSQLOrder.getRandomOrder())); + } + return orderBys; + } + + private YSQLExpression generateFunctionWithUnknownResult(int depth, YSQLDataType type) { + List supportedFunctions = YSQLFunctionWithUnknownResult + .getSupportedFunctions(type); + // filters functions by allowed type (STABLE 's', IMMUTABLE 'i', VOLATILE 'v') + supportedFunctions = supportedFunctions.stream() + .filter(f -> allowedFunctionTypes.contains(functionsAndTypes.get(f.getName()))) + .collect(Collectors.toList()); + if (supportedFunctions.isEmpty()) { + throw new IgnoreMeException(); + } + YSQLFunctionWithUnknownResult randomFunction = Randomly.fromList(supportedFunctions); + return new YSQLFunction(randomFunction, type, randomFunction.getArguments(type, this, depth + 1)); + } + + private YSQLExpression generateFunctionWithKnownResult(int depth, YSQLDataType type) { + List functions = Stream.of(YSQLFunction.YSQLFunctionWithResult.values()) + .filter(f -> f.supportsReturnType(type)).collect(Collectors.toList()); + // filters functions by allowed type (STABLE 's', IMMUTABLE 'i', VOLATILE 'v') + functions = functions.stream().filter(f -> allowedFunctionTypes.contains(functionsAndTypes.get(f.getName()))) + .collect(Collectors.toList()); + if (functions.isEmpty()) { + throw new IgnoreMeException(); + } + YSQLFunction.YSQLFunctionWithResult randomFunction = Randomly.fromList(functions); + int nrArgs = randomFunction.getNrArgs(); + if (randomFunction.isVariadic()) { + nrArgs += Randomly.smallNumber(); + } + YSQLDataType[] argTypes = randomFunction.getInputTypesForReturnType(type, nrArgs); + YSQLExpression[] args = new YSQLExpression[nrArgs]; + do { + for (int i = 0; i < args.length; i++) { + args[i] = generateExpression(depth + 1, argTypes[i]); + } + } while (!randomFunction.checkArguments(args)); + return new YSQLFunction(randomFunction, type, args); + } + + private YSQLExpression generateBooleanExpression(int depth) { + List validOptions = new ArrayList<>(Arrays.asList(BooleanExpression.values())); + if (YSQLProvider.generateOnlyKnown) { + validOptions.remove(BooleanExpression.SIMILAR_TO); + validOptions.remove(BooleanExpression.POSIX_REGEX); + validOptions.remove(BooleanExpression.BINARY_RANGE_COMPARISON); + } + BooleanExpression option = Randomly.fromList(validOptions); + switch (option) { + case POSTFIX_OPERATOR: + YSQLPostfixOperation.PostfixOperator random = YSQLPostfixOperation.PostfixOperator.getRandom(); + return YSQLPostfixOperation + .create(generateExpression(depth + 1, Randomly.fromOptions(random.getInputDataTypes())), random); + case IN_OPERATION: + return inOperation(depth + 1); + case NOT: + return new YSQLPrefixOperation(generateExpression(depth + 1, YSQLDataType.BOOLEAN), + YSQLPrefixOperation.PrefixOperator.NOT); + case BINARY_LOGICAL_OPERATOR: + YSQLExpression first = generateExpression(depth + 1, YSQLDataType.BOOLEAN); + int nr = Randomly.smallNumber() + 1; + for (int i = 0; i < nr; i++) { + first = new YSQLBinaryLogicalOperation(first, generateExpression(depth + 1, YSQLDataType.BOOLEAN), + YSQLBinaryLogicalOperation.BinaryLogicalOperator.getRandom()); + } + return first; + case BINARY_COMPARISON: + YSQLDataType dataType = getMeaningfulType(); + return generateComparison(depth, dataType); + case CAST: + return new YSQLCastOperation(generateExpression(depth + 1), getCompoundDataType(YSQLDataType.BOOLEAN)); + case FUNCTION: + return generateFunction(depth + 1, YSQLDataType.BOOLEAN); + case BETWEEN: + YSQLDataType type = getMeaningfulType(); + return new YSQLBetweenOperation(generateExpression(depth + 1, type), generateExpression(depth + 1, type), + generateExpression(depth + 1, type), Randomly.getBoolean()); + case SIMILAR_TO: + assert !expectedResult; + // TODO also generate the escape character + return new YSQLSimilarTo(generateExpression(depth + 1, YSQLDataType.TEXT), + generateExpression(depth + 1, YSQLDataType.TEXT), null); + case POSIX_REGEX: + assert !expectedResult; + return new YSQLPOSIXRegularExpression(generateExpression(depth + 1, YSQLDataType.TEXT), + generateExpression(depth + 1, YSQLDataType.TEXT), + YSQLPOSIXRegularExpression.POSIXRegex.getRandom()); + case BINARY_RANGE_COMPARISON: + // TODO element check + return new YSQLBinaryRangeOperation(YSQLBinaryRangeOperation.YSQLBinaryRangeComparisonOperator.getRandom(), + generateExpression(depth + 1, YSQLDataType.RANGE), + generateExpression(depth + 1, YSQLDataType.RANGE)); + default: + throw new AssertionError(); + } + } + + private YSQLDataType getMeaningfulType() { + // make it more likely that the expression does not only consist of constant + // expressions + if (Randomly.getBooleanWithSmallProbability() || columns == null || columns.isEmpty()) { + return YSQLDataType.getRandomType(); + } else { + return Randomly.fromList(columns).getType(); + } + } + + private YSQLExpression generateFunction(int depth, YSQLDataType type) { + if (YSQLProvider.generateOnlyKnown || Randomly.getBoolean()) { + return generateFunctionWithKnownResult(depth, type); + } else { + return generateFunctionWithUnknownResult(depth, type); + } + } + + private YSQLExpression generateComparison(int depth, YSQLDataType dataType) { + YSQLExpression leftExpr = generateExpression(depth + 1, dataType); + YSQLExpression rightExpr = generateExpression(depth + 1, dataType); + return getComparison(leftExpr, rightExpr); + } + + private YSQLExpression getComparison(YSQLExpression leftExpr, YSQLExpression rightExpr) { + return new YSQLBinaryComparisonOperation(leftExpr, rightExpr, + YSQLBinaryComparisonOperation.YSQLBinaryComparisonOperator.getRandom()); + } + + private YSQLExpression inOperation(int depth) { + YSQLDataType type = YSQLDataType.getRandomType(); + YSQLExpression leftExpr = generateExpression(depth + 1, type); + List rightExpr = new ArrayList<>(); + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + rightExpr.add(generateExpression(depth + 1, type)); + } + return new YSQLInOperation(leftExpr, rightExpr, Randomly.getBoolean()); + } + + public YSQLExpression generateExpression(int depth, YSQLDataType originalType) { + YSQLDataType dataType = originalType; + if (dataType == YSQLDataType.REAL && Randomly.getBoolean()) { + dataType = Randomly.fromOptions(YSQLDataType.INT, YSQLDataType.FLOAT); + } + if (dataType == YSQLDataType.FLOAT && Randomly.getBoolean()) { + dataType = YSQLDataType.INT; + } + return generateExpressionInternal(depth, dataType); + } + + private YSQLExpression generateExpressionInternal(int depth, YSQLDataType dataType) throws AssertionError { + if (allowAggregateFunctions && Randomly.getBoolean()) { + allowAggregateFunctions = false; // aggregate function calls cannot be nested + return getAggregate(dataType); + } + if (Randomly.getBooleanWithRatherLowProbability() || depth > maxDepth) { + // generic expression + if (Randomly.getBoolean() || depth > maxDepth) { + if (Randomly.getBooleanWithRatherLowProbability()) { + return generateConstant(r, dataType); + } else { + if (filterColumns(dataType).isEmpty()) { + return generateConstant(r, dataType); + } else { + return createColumnOfType(dataType); + } + } + } else { + if (Randomly.getBoolean()) { + return new YSQLCastOperation(generateExpression(depth + 1), getCompoundDataType(dataType)); + } else { + return generateFunctionWithUnknownResult(depth, dataType); + } + } + } else { + switch (dataType) { + case BOOLEAN: + return generateBooleanExpression(depth); + case INT: + return generateIntExpression(depth); + case TEXT: + return generateTextExpression(depth); + case DECIMAL: + case REAL: + case FLOAT: + case MONEY: + case INET: + return generateConstant(r, dataType); + case BYTEA: + return generateByteExpression(); + case BIT: + return generateBitExpression(depth); + case RANGE: + return generateRangeExpression(depth); + default: + throw new AssertionError(dataType); + } + } + } + + private YSQLExpression generateRangeExpression(int depth) { + RangeExpression option; + List validOptions = new ArrayList<>(Arrays.asList(RangeExpression.values())); + option = Randomly.fromList(validOptions); + switch (option) { + case BINARY_OP: + return new YSQLBinaryRangeOperation(YSQLBinaryRangeOperation.YSQLBinaryRangeOperator.getRandom(), + generateExpression(depth + 1, YSQLDataType.RANGE), + generateExpression(depth + 1, YSQLDataType.RANGE)); + default: + throw new AssertionError(option); + } + } + + private YSQLExpression generateTextExpression(int depth) { + TextExpression option; + List validOptions = new ArrayList<>(Arrays.asList(TextExpression.values())); + option = Randomly.fromList(validOptions); + + switch (option) { + case CAST: + return new YSQLCastOperation(generateExpression(depth + 1), getCompoundDataType(YSQLDataType.TEXT)); + case FUNCTION: + return generateFunction(depth + 1, YSQLDataType.TEXT); + case CONCAT: + return generateConcat(depth); + default: + throw new AssertionError(); + } + } + + private YSQLExpression generateConcat(int depth) { + YSQLExpression left = generateExpression(depth + 1, YSQLDataType.TEXT); + YSQLExpression right = generateExpression(depth + 1); + return new YSQLConcatOperation(left, right); + } + + private YSQLExpression generateByteExpression() { + return YSQLConstant.createByteConstant("Th\\000omas"); + } + + private YSQLExpression generateBitExpression(int depth) { + BitExpression option; + option = Randomly.fromOptions(BitExpression.values()); + switch (option) { + case BINARY_OPERATION: + return new YSQLBinaryBitOperation(YSQLBinaryBitOperation.YSQLBinaryBitOperator.getRandom(), + generateExpression(depth + 1, YSQLDataType.BIT), generateExpression(depth + 1, YSQLDataType.BIT)); + default: + throw new AssertionError(); + } + } + + private YSQLExpression generateIntExpression(int depth) { + IntExpression option; + option = Randomly.fromOptions(IntExpression.values()); + switch (option) { + case CAST: + return new YSQLCastOperation(generateExpression(depth + 1), getCompoundDataType(YSQLDataType.INT)); + case UNARY_OPERATION: + YSQLExpression intExpression = generateExpression(depth + 1, YSQLDataType.INT); + return new YSQLPrefixOperation(intExpression, Randomly.getBoolean() + ? YSQLPrefixOperation.PrefixOperator.UNARY_PLUS : YSQLPrefixOperation.PrefixOperator.UNARY_MINUS); + case FUNCTION: + return generateFunction(depth + 1, YSQLDataType.INT); + case BINARY_ARITHMETIC_EXPRESSION: + return new YSQLBinaryArithmeticOperation(generateExpression(depth + 1, YSQLDataType.INT), + generateExpression(depth + 1, YSQLDataType.INT), + YSQLBinaryArithmeticOperation.YSQLBinaryOperator.getRandom()); + default: + throw new AssertionError(); + } + } + + private YSQLExpression createColumnOfType(YSQLDataType type) { + List columns = filterColumns(type); + YSQLColumn fromList = Randomly.fromList(columns); + YSQLConstant value = rw == null ? null : rw.getValues().get(fromList); + return YSQLColumnValue.create(fromList, value); + } + + final List filterColumns(YSQLDataType type) { + if (columns == null) { + return Collections.emptyList(); + } else { + return columns.stream().filter(c -> c.getType() == type).collect(Collectors.toList()); + } + } + + public YSQLExpression generateExpressionWithExpectedResult(YSQLDataType type) { + this.expectedResult = true; + YSQLExpressionGenerator gen = new YSQLExpressionGenerator(globalState).setColumns(columns).setRowValue(rw); + YSQLExpression expr; + do { + expr = gen.generateExpression(type); + } while (expr.getExpectedValue() == null); + return expr; + } + + public List generateExpressions(int nr) { + List expressions = new ArrayList<>(); + for (int i = 0; i < nr; i++) { + expressions.add(generateExpression(0)); + } + return expressions; + } + + public YSQLExpression generateExpression(YSQLDataType dataType) { + return generateExpression(0, dataType); + } + + public YSQLExpressionGenerator setGlobalState(YSQLGlobalState globalState) { + this.globalState = globalState; + return this; + } + + public YSQLExpression generateHavingClause() { + this.allowAggregateFunctions = true; + YSQLExpression expression = generateExpression(YSQLDataType.BOOLEAN); + this.allowAggregateFunctions = false; + return expression; + } + + public YSQLExpression generateAggregate() { + return getAggregate(YSQLDataType.getRandomType()); + } + + private YSQLExpression getAggregate(YSQLDataType dataType) { + List aggregates = YSQLAggregate.YSQLAggregateFunction + .getAggregates(dataType); + YSQLAggregate.YSQLAggregateFunction agg = Randomly.fromList(aggregates); + return generateArgsForAggregate(dataType, agg); + } + + public YSQLAggregate generateArgsForAggregate(YSQLDataType dataType, YSQLAggregate.YSQLAggregateFunction agg) { + List types = agg.getTypes(dataType); + List args = new ArrayList<>(); + for (YSQLDataType argType : types) { + args.add(generateExpression(argType)); + } + return new YSQLAggregate(args, agg); + } + + public YSQLExpressionGenerator allowAggregates(boolean value) { + allowAggregateFunctions = value; + return this; + } + + @Override + public YSQLExpression generatePredicate() { + return generateExpression(YSQLDataType.BOOLEAN); + } + + @Override + public YSQLExpression negatePredicate(YSQLExpression predicate) { + return new YSQLPrefixOperation(predicate, YSQLPrefixOperation.PrefixOperator.NOT); + } + + @Override + public YSQLExpression isNull(YSQLExpression expr) { + return new YSQLPostfixOperation(expr, YSQLPostfixOperation.PostfixOperator.IS_NULL); + } + + private enum BooleanExpression { + POSTFIX_OPERATOR, NOT, BINARY_LOGICAL_OPERATOR, BINARY_COMPARISON, FUNCTION, CAST, BETWEEN, IN_OPERATION, + SIMILAR_TO, POSIX_REGEX, BINARY_RANGE_COMPARISON + } + + private enum RangeExpression { + BINARY_OP + } + + private enum TextExpression { + CAST, FUNCTION, CONCAT + } + + private enum BitExpression { + BINARY_OPERATION + } + + private enum IntExpression { + UNARY_OPERATION, FUNCTION, CAST, BINARY_ARITHMETIC_EXPRESSION + } + + public static YSQLSelect.YSQLSubquery createSubquery(YSQLGlobalState globalState, String name, YSQLTables tables) { + List columns = new ArrayList<>(); + YSQLExpressionGenerator gen = new YSQLExpressionGenerator(globalState).setColumns(tables.getColumns()); + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + columns.add(gen.generateExpression(0)); + } + YSQLSelect select = new YSQLSelect(); + select.setFromList(tables.getTables().stream().map(t -> new YSQLSelect.YSQLFromTable(t, Randomly.getBoolean())) + .collect(Collectors.toList())); + select.setFetchColumns(columns); + if (Randomly.getBoolean()) { + select.setWhereClause(gen.generateExpression(0, YSQLDataType.BOOLEAN)); + } + if (Randomly.getBooleanWithRatherLowProbability()) { + select.setOrderByClauses(gen.generateOrderBys()); + } + if (Randomly.getBoolean()) { + select.setLimitClause(YSQLConstant.createIntConstant(Randomly.getPositiveOrZeroNonCachedInteger())); + if (Randomly.getBoolean()) { + select.setOffsetClause(YSQLConstant.createIntConstant(Randomly.getPositiveOrZeroNonCachedInteger())); + } + } + if (Randomly.getBooleanWithRatherLowProbability()) { + select.setForClause(YSQLSelect.ForClause.getRandom()); + } + return new YSQLSelect.YSQLSubquery(select, name); + } + + @Override + public YSQLExpressionGenerator setTablesAndColumns(AbstractTables tables) { + this.columns = tables.getColumns(); + this.tables = tables.getTables(); + + return this; + } + + @Override + public YSQLExpression generateBooleanExpression() { + return generateExpression(YSQLDataType.BOOLEAN); + } + + @Override + public YSQLSelect generateSelect() { + return new YSQLSelect(); + } + + @Override + public List getRandomJoinClauses() { + List joinStatements = new ArrayList<>(); + YSQLExpressionGenerator gen = new YSQLExpressionGenerator(globalState).setColumns(columns); + for (int i = 1; i < tables.size(); i++) { + YSQLExpression joinClause = gen.generateExpression(YSQLDataType.BOOLEAN); + YSQLTable table = Randomly.fromList(tables); + tables.remove(table); + YSQLJoin.YSQLJoinType options = YSQLJoin.YSQLJoinType.getRandom(); + YSQLJoin j = new YSQLJoin(new YSQLSelect.YSQLFromTable(table, Randomly.getBoolean()), joinClause, options); + joinStatements.add(j); + } + // JOIN subqueries + for (int i = 0; i < Randomly.smallNumber(); i++) { + YSQLTables subqueryTables = globalState.getSchema().getRandomTableNonEmptyTables(); + YSQLSelect.YSQLSubquery subquery = createSubquery(globalState, String.format("sub%d", i), subqueryTables); + YSQLExpression joinClause = gen.generateExpression(YSQLDataType.BOOLEAN); + YSQLJoin.YSQLJoinType options = YSQLJoin.YSQLJoinType.getRandom(); + YSQLJoin j = new YSQLJoin(subquery, joinClause, options); + joinStatements.add(j); + } + return joinStatements; + } + + @Override + public List getTableRefs() { + return tables.stream().map(t -> new YSQLSelect.YSQLFromTable(t, Randomly.getBoolean())) + .collect(Collectors.toList()); + } + + @Override + public String generateOptimizedQueryString(YSQLSelect select, YSQLExpression whereCondition, + boolean shouldUseAggregate) { + if (shouldUseAggregate) { + YSQLAggregate aggr = new YSQLAggregate(List.of(new YSQLColumnValue(YSQLColumn.createDummy("*"), null)), + YSQLAggregateFunction.COUNT); + select.setFetchColumns(List.of(aggr)); + } else { + YSQLColumnValue allColumns = new YSQLColumnValue(Randomly.fromList(columns), null); + select.setFetchColumns(Arrays.asList(allColumns)); + if (Randomly.getBooleanWithSmallProbability()) { + select.setOrderByClauses(generateOrderBys()); + } + select.setWhereClause(whereCondition); + } + + return select.asString(); + } + + @Override + public String generateUnoptimizedQueryString(YSQLSelect select, YSQLExpression whereCondition) { + YSQLCastOperation isTrue = new YSQLCastOperation(whereCondition, YSQLCompoundDataType.create(YSQLDataType.INT)); + YSQLPostfixText asText = new YSQLPostfixText(isTrue, " as count", null, YSQLDataType.INT); + select.setFetchColumns(Collections.singletonList(asText)); + select.setSelectType(YSQLSelect.SelectType.ALL); + select.setWhereClause(null); + + return "SELECT SUM(count) FROM (" + select.asString() + ") as res"; + } + + @Override + public List generateFetchColumns(boolean shouldCreateDummy) { + if (shouldCreateDummy && Randomly.getBooleanWithRatherLowProbability()) { + return List.of(new YSQLColumnValue(YSQLColumn.createDummy("*"), null)); + } + return columns.stream().map(c -> new YSQLColumnValue(c, null)).collect(Collectors.toList()); + } +} diff --git a/src/sqlancer/yugabyte/ysql/gen/YSQLIndexGenerator.java b/src/sqlancer/yugabyte/ysql/gen/YSQLIndexGenerator.java new file mode 100644 index 000000000..6077dcb1e --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/gen/YSQLIndexGenerator.java @@ -0,0 +1,143 @@ +package sqlancer.yugabyte.ysql.gen; + +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.common.DBMSCommon; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.yugabyte.ysql.YSQLErrors; +import sqlancer.yugabyte.ysql.YSQLGlobalState; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLColumn; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLIndex; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLTable; +import sqlancer.yugabyte.ysql.YSQLVisitor; +import sqlancer.yugabyte.ysql.ast.YSQLExpression; + +public final class YSQLIndexGenerator { + + private YSQLIndexGenerator() { + } + + public static SQLQueryAdapter generate(YSQLGlobalState globalState) { + ExpectedErrors errors = new ExpectedErrors(); + StringBuilder sb = new StringBuilder(); + sb.append("CREATE"); + if (Randomly.getBoolean()) { + sb.append(" UNIQUE"); + } + sb.append(" INDEX "); + YSQLTable randomTable = globalState.getSchema().getRandomTable(t -> !t.isView()); // TODO: materialized + // views + String indexName = getNewIndexName(randomTable); + sb.append(indexName); + sb.append(" ON "); + if (Randomly.getBoolean()) { + sb.append("ONLY "); + } + sb.append(randomTable.getName()); + IndexType method; + if (Randomly.getBoolean()) { + sb.append(" USING "); + method = Randomly.fromOptions(IndexType.values()); + sb.append(method); + } else { + method = IndexType.BTREE; + } + + sb.append("("); + if (method == IndexType.HASH) { + sb.append(randomTable.getRandomColumn().getName()); + } else { + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + if (i != 0) { + sb.append(", "); + } + if (Randomly.getBoolean()) { + sb.append(randomTable.getRandomColumn().getName()); + } else { + sb.append("("); + YSQLExpression expression = YSQLExpressionGenerator.generateExpression(globalState, + randomTable.getColumns()); + sb.append(YSQLVisitor.asString(expression)); + sb.append(")"); + } + + if (Randomly.getBooleanWithRatherLowProbability()) { + sb.append(" "); + sb.append(globalState.getRandomOpclass()); + errors.add("does not accept"); + errors.add("does not exist for access method"); + } + if (Randomly.getBoolean()) { + sb.append(" "); + sb.append(Randomly.fromOptions("ASC", "DESC")); + } + if (Randomly.getBooleanWithRatherLowProbability()) { + sb.append(" NULLS "); + sb.append(Randomly.fromOptions("FIRST", "LAST")); + } + } + } + + sb.append(")"); + if (Randomly.getBoolean() && method != IndexType.HASH) { + sb.append(" INCLUDE("); + List columns = randomTable.getRandomNonEmptyColumnSubset(); + sb.append(columns.stream().map(AbstractTableColumn::getName).collect(Collectors.joining(", "))); + sb.append(")"); + } + if (Randomly.getBoolean()) { + sb.append(" WHERE "); + YSQLExpression expr = new YSQLExpressionGenerator(globalState).setColumns(randomTable.getColumns()) + .setGlobalState(globalState).generateExpression(YSQLDataType.BOOLEAN); + sb.append(YSQLVisitor.asString(expr)); + } + errors.add("already contains data"); // CONCURRENT INDEX failed + errors.add("You might need to add explicit type casts"); + errors.add("INDEX on column of type"); + errors.add("collations are not supported"); // TODO check + errors.add("because it has pending trigger events"); + errors.add("duplicate key value violates unique constraint"); + errors.add("could not determine which collation to use for"); + errors.add("index method \"gist\" not supported yet"); + errors.add("is duplicated"); + errors.add("already exists"); + errors.add("could not create unique index"); + errors.add("has no default operator class"); + errors.add("does not support"); + errors.add("cannot cast"); + errors.add("unsupported UNIQUE constraint with partition key definition"); + errors.add("insufficient columns in UNIQUE constraint definition"); + errors.add("invalid input syntax for"); + errors.add("must be type "); + errors.add("integer out of range"); + errors.add("division by zero"); + errors.add("out of range"); + errors.add("functions in index predicate must be marked IMMUTABLE"); + errors.add("functions in index expression must be marked IMMUTABLE"); + errors.add("result of range difference would not be contiguous"); + errors.add("which is part of the partition key"); + YSQLErrors.addCommonExpressionErrors(errors); + return new SQLQueryAdapter(sb.toString(), errors); + } + + private static String getNewIndexName(YSQLTable randomTable) { + List indexes = randomTable.getIndexes(); + int indexI = 0; + while (true) { + String indexName = DBMSCommon.createIndexName(indexI++); + if (indexes.stream().noneMatch(i -> i.getIndexName().equals(indexName))) { + return indexName; + } + } + } + + public enum IndexType { + BTREE, HASH, GIST, GIN + } + +} diff --git a/src/sqlancer/yugabyte/ysql/gen/YSQLInsertGenerator.java b/src/sqlancer/yugabyte/ysql/gen/YSQLInsertGenerator.java new file mode 100644 index 000000000..c26459f31 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/gen/YSQLInsertGenerator.java @@ -0,0 +1,128 @@ +package sqlancer.yugabyte.ysql.gen; + +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.yugabyte.ysql.YSQLErrors; +import sqlancer.yugabyte.ysql.YSQLGlobalState; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLColumn; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLTable; +import sqlancer.yugabyte.ysql.YSQLVisitor; +import sqlancer.yugabyte.ysql.ast.YSQLExpression; + +public final class YSQLInsertGenerator { + + private YSQLInsertGenerator() { + } + + public static SQLQueryAdapter insert(YSQLGlobalState globalState) { + YSQLTable table = globalState.getSchema().getRandomTable(YSQLTable::isInsertable); + ExpectedErrors errors = new ExpectedErrors(); + errors.add("cannot insert into column"); + YSQLErrors.addCommonExpressionErrors(errors); + YSQLErrors.addCommonInsertUpdateErrors(errors); + YSQLErrors.addCommonExpressionErrors(errors); + errors.add("multiple assignments to same column"); + errors.add("violates foreign key constraint"); + errors.add("value too long for type character varying"); + errors.add("conflicting key value violates exclusion constraint"); + errors.add("violates not-null constraint"); + errors.add("current transaction is aborted"); + errors.add("bit string too long"); + errors.add("new row violates check option for view"); + errors.add("reached maximum value of sequence"); + errors.add("but expression is of type"); + StringBuilder sb = new StringBuilder(); + sb.append("INSERT INTO "); + sb.append(table.getName()); + List columns = table.getRandomNonEmptyColumnSubset(); + sb.append("("); + sb.append(columns.stream().map(AbstractTableColumn::getName).collect(Collectors.joining(", "))); + sb.append(")"); + if (Randomly.getBooleanWithRatherLowProbability()) { + sb.append(" OVERRIDING"); + sb.append(" "); + sb.append(Randomly.fromOptions("SYSTEM", "USER")); + sb.append(" VALUE"); + } + sb.append(" VALUES"); + + if (globalState.getDbmsSpecificOptions().allowBulkInsert && Randomly.getBooleanWithSmallProbability()) { + StringBuilder sbRowValue = new StringBuilder(); + sbRowValue.append("("); + for (int i = 0; i < columns.size(); i++) { + if (i != 0) { + sbRowValue.append(", "); + } + sbRowValue.append(YSQLVisitor.asString( + YSQLExpressionGenerator.generateConstant(globalState.getRandomly(), columns.get(i).getType()))); + } + sbRowValue.append(")"); + + int n = (int) Randomly.getNotCachedInteger(100, 1000); + for (int i = 0; i < n; i++) { + if (i != 0) { + sb.append(", "); + } + sb.append(sbRowValue); + } + } else { + int n = Randomly.smallNumber() + 1; + for (int i = 0; i < n; i++) { + if (i != 0) { + sb.append(", "); + } + insertRow(globalState, sb, columns, n == 1); + } + } + if (Randomly.getBooleanWithRatherLowProbability()) { + sb.append(" ON CONFLICT "); + if (Randomly.getBoolean()) { + sb.append("("); + sb.append(table.getRandomColumn().getName()); + sb.append(")"); + errors.add("there is no unique or exclusion constraint matching the ON CONFLICT specification"); + } + sb.append(" DO NOTHING"); + } + errors.add("duplicate key value violates unique constraint"); + errors.add("identity column defined as GENERATED ALWAYS"); + errors.add("out of range"); + errors.add("violates check constraint"); + errors.add("no partition of relation"); + errors.add("invalid input syntax"); + errors.add("division by zero"); + errors.add("violates foreign key constraint"); + errors.add("data type unknown"); + return new SQLQueryAdapter(sb.toString(), errors); + } + + private static void insertRow(YSQLGlobalState globalState, StringBuilder sb, List columns, + boolean canBeDefault) { + sb.append("("); + for (int i = 0; i < columns.size(); i++) { + if (i != 0) { + sb.append(", "); + } + if (!Randomly.getBooleanWithSmallProbability() || !canBeDefault) { + YSQLExpression generateConstant; + if (Randomly.getBoolean()) { + generateConstant = YSQLExpressionGenerator.generateConstant(globalState.getRandomly(), + columns.get(i).getType()); + } else { + generateConstant = new YSQLExpressionGenerator(globalState) + .generateExpression(columns.get(i).getType()); + } + sb.append(YSQLVisitor.asString(generateConstant)); + } else { + sb.append("DEFAULT"); + } + } + sb.append(")"); + } + +} diff --git a/src/sqlancer/yugabyte/ysql/gen/YSQLNotifyGenerator.java b/src/sqlancer/yugabyte/ysql/gen/YSQLNotifyGenerator.java new file mode 100644 index 000000000..a6fb58d8e --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/gen/YSQLNotifyGenerator.java @@ -0,0 +1,45 @@ +package sqlancer.yugabyte.ysql.gen; + +import sqlancer.Randomly; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.yugabyte.ysql.YSQLGlobalState; + +public final class YSQLNotifyGenerator { + + private YSQLNotifyGenerator() { + } + + private static String getChannel() { + return Randomly.fromOptions("asdf", "test"); + } + + public static SQLQueryAdapter createNotify(YSQLGlobalState globalState) { + StringBuilder sb = new StringBuilder(); + sb.append("NOTIFY "); + sb.append(getChannel()); + if (Randomly.getBoolean()) { + sb.append(", "); + sb.append("'"); + sb.append(globalState.getRandomly().getString().replace("'", "''")); + sb.append("'"); + } + return new SQLQueryAdapter(sb.toString()); + } + + public static SQLQueryAdapter createListen() { + String sb = "LISTEN " + getChannel(); + return new SQLQueryAdapter(sb); + } + + public static SQLQueryAdapter createUnlisten() { + StringBuilder sb = new StringBuilder(); + sb.append("UNLISTEN "); + if (Randomly.getBoolean()) { + sb.append(getChannel()); + } else { + sb.append("*"); + } + return new SQLQueryAdapter(sb.toString()); + } + +} diff --git a/src/sqlancer/yugabyte/ysql/gen/YSQLRandomQueryGenerator.java b/src/sqlancer/yugabyte/ysql/gen/YSQLRandomQueryGenerator.java new file mode 100644 index 000000000..9821dcb7d --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/gen/YSQLRandomQueryGenerator.java @@ -0,0 +1,62 @@ +package sqlancer.yugabyte.ysql.gen; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.yugabyte.ysql.YSQLGlobalState; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLTables; +import sqlancer.yugabyte.ysql.ast.YSQLConstant; +import sqlancer.yugabyte.ysql.ast.YSQLExpression; +import sqlancer.yugabyte.ysql.ast.YSQLSelect; +import sqlancer.yugabyte.ysql.ast.YSQLSelect.ForClause; +import sqlancer.yugabyte.ysql.ast.YSQLSelect.SelectType; +import sqlancer.yugabyte.ysql.ast.YSQLSelect.YSQLFromTable; + +public final class YSQLRandomQueryGenerator { + + private YSQLRandomQueryGenerator() { + } + + public static YSQLSelect createRandomQuery(int nrColumns, YSQLGlobalState globalState) { + List columns = new ArrayList<>(); + YSQLTables tables = globalState.getSchema().getRandomTableNonEmptyTables(); + YSQLExpressionGenerator gen = new YSQLExpressionGenerator(globalState).setColumns(tables.getColumns()); + for (int i = 0; i < nrColumns; i++) { + columns.add(gen.generateExpression(0)); + } + YSQLSelect select = new YSQLSelect(); + select.setSelectType(SelectType.getRandom()); + if (select.getSelectOption() == SelectType.DISTINCT && Randomly.getBoolean()) { + select.setDistinctOnClause(gen.generateExpression(0)); + } + select.setFromList(tables.getTables().stream().map(t -> new YSQLFromTable(t, Randomly.getBoolean())) + .collect(Collectors.toList())); + select.setFetchColumns(columns); + if (Randomly.getBoolean()) { + select.setWhereClause(gen.generateExpression(0, YSQLDataType.BOOLEAN)); + } + if (Randomly.getBooleanWithRatherLowProbability()) { + select.setGroupByExpressions(gen.generateExpressions(Randomly.smallNumber() + 1)); + if (Randomly.getBoolean()) { + select.setHavingClause(gen.generateHavingClause()); + } + } + if (Randomly.getBooleanWithRatherLowProbability()) { + select.setOrderByClauses(gen.generateOrderBys()); + } + if (Randomly.getBoolean()) { + select.setLimitClause(YSQLConstant.createIntConstant(Randomly.getPositiveOrZeroNonCachedInteger())); + if (Randomly.getBoolean()) { + select.setOffsetClause(YSQLConstant.createIntConstant(Randomly.getPositiveOrZeroNonCachedInteger())); + } + } + if (Randomly.getBooleanWithRatherLowProbability()) { + select.setForClause(ForClause.getRandom()); + } + return select; + } + +} diff --git a/src/sqlancer/yugabyte/ysql/gen/YSQLReindexGenerator.java b/src/sqlancer/yugabyte/ysql/gen/YSQLReindexGenerator.java new file mode 100644 index 000000000..7526d8dc6 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/gen/YSQLReindexGenerator.java @@ -0,0 +1,58 @@ +package sqlancer.yugabyte.ysql.gen; + +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.yugabyte.ysql.YSQLGlobalState; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLIndex; + +public final class YSQLReindexGenerator { + + private YSQLReindexGenerator() { + } + + public static SQLQueryAdapter create(YSQLGlobalState globalState) { + ExpectedErrors errors = new ExpectedErrors(); + errors.add("could not create unique index"); // CONCURRENT INDEX + StringBuilder sb = new StringBuilder(); + sb.append("REINDEX"); + // if (Randomly.getBoolean()) { + // sb.append(" VERBOSE"); + // } + sb.append(" "); + Scope scope = Randomly.fromOptions(Scope.values()); + switch (scope) { + case INDEX: + sb.append("INDEX "); + List indexes = globalState.getSchema().getRandomTable().getIndexes(); + if (indexes.isEmpty()) { + throw new IgnoreMeException(); + } + sb.append(indexes.stream().map(YSQLIndex::getIndexName).collect(Collectors.joining())); + break; + case TABLE: + sb.append("TABLE "); + sb.append(globalState.getSchema().getRandomTable(t -> !t.isView()).getName()); + break; + case DATABASE: + sb.append("DATABASE "); + sb.append(globalState.getSchema().getDatabaseName()); + break; + default: + throw new AssertionError(scope); + } + errors.add("already contains data"); // FIXME bug report + errors.add("does not exist"); // internal index + errors.add("REINDEX is not yet implemented for partitioned indexes"); + return new SQLQueryAdapter(sb.toString(), errors); + } + + private enum Scope { + INDEX, TABLE, DATABASE + } + +} diff --git a/src/sqlancer/yugabyte/ysql/gen/YSQLSequenceGenerator.java b/src/sqlancer/yugabyte/ysql/gen/YSQLSequenceGenerator.java new file mode 100644 index 000000000..3d7104b5e --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/gen/YSQLSequenceGenerator.java @@ -0,0 +1,90 @@ +package sqlancer.yugabyte.ysql.gen; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.yugabyte.ysql.YSQLGlobalState; + +public final class YSQLSequenceGenerator { + + private YSQLSequenceGenerator() { + } + + public static SQLQueryAdapter createSequence(YSQLGlobalState globalState) { + ExpectedErrors errors = new ExpectedErrors(); + StringBuilder sb = new StringBuilder("CREATE"); + if (Randomly.getBoolean()) { + sb.append(" "); + sb.append(Randomly.fromOptions("TEMPORARY", "TEMP")); + } + sb.append(" SEQUENCE"); + // TODO keep track of sequences + sb.append(" IF NOT EXISTS"); + // TODO generate sequence names + sb.append(" seq"); + if (Randomly.getBoolean()) { + sb.append(" AS "); + sb.append(Randomly.fromOptions("smallint", "integer", "bigint")); + } + if (Randomly.getBoolean()) { + sb.append(" INCREMENT"); + if (Randomly.getBoolean()) { + sb.append(" BY"); + } + sb.append(" "); + sb.append(globalState.getRandomly().getInteger()); + errors.add("INCREMENT must not be zero"); + } + if (Randomly.getBoolean()) { + if (Randomly.getBoolean()) { + sb.append(" MINVALUE"); + sb.append(" "); + sb.append(globalState.getRandomly().getInteger()); + } else { + sb.append(" NO MINVALUE"); + } + errors.add("must be less than MAXVALUE"); + } + if (Randomly.getBoolean()) { + if (Randomly.getBoolean()) { + sb.append(" MAXVALUE"); + sb.append(" "); + sb.append(globalState.getRandomly().getInteger()); + } else { + sb.append(" NO MAXVALUE"); + } + errors.add("must be less than MAXVALUE"); + } + if (Randomly.getBoolean()) { + sb.append(" START"); + if (Randomly.getBoolean()) { + sb.append(" WITH"); + } + sb.append(" "); + sb.append(globalState.getRandomly().getInteger()); + errors.add("cannot be less than MINVALUE"); + errors.add("cannot be greater than MAXVALUE"); + } + if (Randomly.getBoolean()) { + sb.append(" CACHE "); + sb.append(globalState.getRandomly().getPositiveIntegerNotNull()); + } + errors.add("is out of range"); + if (Randomly.getBoolean()) { + if (Randomly.getBoolean()) { + sb.append(" NO"); + } + sb.append(" CYCLE"); + } + if (Randomly.getBoolean()) { + sb.append(" OWNED BY "); + // if (Randomly.getBoolean()) { + sb.append("NONE"); + // } else { + // sb.append(s.getRandomTable().getRandomColumn().getFullQualifiedName()); + // } + } + return new SQLQueryAdapter(sb.toString(), errors); + } + +} diff --git a/src/sqlancer/yugabyte/ysql/gen/YSQLSetGenerator.java b/src/sqlancer/yugabyte/ysql/gen/YSQLSetGenerator.java new file mode 100644 index 000000000..3406ca3de --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/gen/YSQLSetGenerator.java @@ -0,0 +1,196 @@ +package sqlancer.yugabyte.ysql.gen; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.function.Function; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.yugabyte.ysql.YSQLGlobalState; + +public final class YSQLSetGenerator { + + private YSQLSetGenerator() { + } + + public static SQLQueryAdapter create(YSQLGlobalState globalState) { + StringBuilder sb = new StringBuilder(); + ArrayList options = new ArrayList<>(Arrays.asList(ConfigurationOption.values())); + options.remove(ConfigurationOption.DEFAULT_WITH_OIDS); + ConfigurationOption option = Randomly.fromList(options); + sb.append("SET "); + if (Randomly.getBoolean()) { + sb.append(Randomly.fromOptions("SESSION", "LOCAL")); + sb.append(" "); + } + sb.append(option.getOptionName()); + sb.append("="); + if (Randomly.getBoolean()) { + sb.append("DEFAULT"); + } else { + sb.append(option.op.apply(globalState.getRandomly())); + } + // todo avoiding props that are not represented in YSQL + ExpectedErrors errors = new ExpectedErrors(); + errors.add("unrecognized configuration parameter"); + errors.add("cannot be changed"); + + return new SQLQueryAdapter(sb.toString(), errors); + } + + private enum ConfigurationOption { + // YUGABYTE + YB_DEBUG_REPORT_ERROR_STACKTRACE("yb_debug_report_error_stacktrace", + (r) -> Randomly.fromOptions("false", "true")), + YB_DEBUG_LOG_CATCACHE_EVENTS("yb_debug_log_catcache_events", (r) -> Randomly.fromOptions("false", "true")), + YB_DEBUG_LOG_INTERNAL_RESTARTS("yb_debug_log_internal_restarts", (r) -> Randomly.fromOptions("false", "true")), + YB_DEBUG_LOG_DOCDB_REQUESTS("yb_debug_log_docdb_requests", (r) -> Randomly.fromOptions("false", "true")), + // YB_READ_FROM_FOLLOWERS("yb_read_from_followers", (r) -> Randomly.fromOptions("false", "true")), + YB_NON_DDL_TXN_FOR_SYS_TABLES_ALLOWED("yb_non_ddl_txn_for_sys_tables_allowed", + (r) -> Randomly.fromOptions("false", "true")), + YB_TRANSACTION_PRIORITY("yb_transaction_priority", + (r) -> Randomly.fromOptions(0, 0.1, 0.2, 0.3, 0.4, 1, 0.9, 0.8, 0.7, 0.6)), + YB_TRANSACTION_PRIORITY_LOWER_BOUND("yb_transaction_priority_lower_bound", + (r) -> Randomly.fromOptions(0, 0.1, 0.2, 0.3, 0.4)), + YB_TRANSACTION_PRIORITY_UPPER_BOUND("yb_transaction_priority_upper_bound", + (r) -> Randomly.fromOptions(1, 0.9, 0.8, 0.7, 0.6)), + YB_FORMAT_FUNCS_INCLUDE_YB_METADATA("yb_format_funcs_include_yb_metadata", + (r) -> Randomly.fromOptions("false", "true")), + YB_ENABLE_GEOLOCATION_COSTING("yb_enable_geolocation_costing", (r) -> Randomly.fromOptions("false", "true")), + YB_BINARY_RESTORE("yb_binary_restore", (r) -> Randomly.fromOptions("false", "true")), + YB_TEST_SYSTEM_CATALOGS_CREATION("yb_test_system_catalogs_creation", + (r) -> Randomly.fromOptions("false", "true")), + YB_TEST_FAIL_NEXT_DDL("yb_test_fail_next_ddl", (r) -> Randomly.fromOptions("false", "true")), + YB_DISABLE_TRANSACTIONAL_WRITES("yb_disable_transactional_writes", + (r) -> Randomly.fromOptions("false", "true")), + YB_ENABLE_OPTIMIZER_STATISTICS("yb_enable_optimizer_statistics", (r) -> Randomly.fromOptions("false", "true")), + YB_ENABLE_EXPRESSION_PUSHDOWN("yb_enable_expression_pushdown", (r) -> Randomly.fromOptions("false", "true")), + YB_ENABLE_UPSERT_MODE("yb_enable_upsert_mode", (r) -> Randomly.fromOptions("false", "true")), + YB_PLANNER_CUSTOM_PLAN_FOR_PARTITION_PRUNING("yb_planner_custom_plan_for_partition_pruning", + (r) -> Randomly.fromOptions("false", "true")), + YB_INDEX_STATE_FLAGS_UPDATE_DELAY("yb_index_state_flags_update_delay", + (r) -> Randomly.getNotCachedInteger(200, 1000)), + YB_TEST_PLANNER_CUSTOM_PLAN_THRESHOLD("yb_test_planner_custom_plan_threshold", + (r) -> Randomly.getNotCachedInteger(1, Integer.MAX_VALUE)), + // YSQL values + YSQL_UPGRADE_MODE("ysql_upgrade_mode", (r) -> Randomly.fromOptions("false", "true")), + YSQL_SESSION_MAX_BATCH_SIZE("ysql_session_max_batch_size", + (r) -> Randomly.getNotCachedInteger(1, Integer.MAX_VALUE)), + YSQL_MAX_IN_FLIGHT_OPS("ysql_max_in_flight_ops", (r) -> Randomly.getNotCachedInteger(1, Integer.MAX_VALUE)), + // https://www.postgresql.org/docs/11/runtime-config-wal.html + // This parameter can only be set at server start. + // WAL_LEVEL("wal_level", (r) -> Randomly.fromOptions("replica", "minimal", "logical")), + // FSYNC("fsync", (r) -> Randomly.fromOptions(1, 0)), + SYNCHRONOUS_COMMIT("synchronous_commit", + (r) -> Randomly.fromOptions("remote_apply", "remote_write", "local", "off")), + WAL_COMPRESSION("wal_compression", (r) -> Randomly.fromOptions(1, 0)), + // wal_buffer: server start + // wal_writer_delay: server start + // wal_writer_flush_after + COMMIT_DELAY("commit_delay", (r) -> r.getInteger(0, 100000)), + COMMIT_SIBLINGS("commit_siblings", (r) -> r.getInteger(0, 1000)), + // 19.5.2. Checkpoints + // checkpoint_timeout + // checkpoint_completion_target + // checkpoint_flush_after + // checkpoint_warning + // max_wal_size + // min_wal_size + // 19.5.3. Archiving + // archive_mode + // archive_command + // archive_timeout + // https://www.postgresql.org/docs/11/runtime-config-statistics.html + // 19.9.1. Query and Index Statistics Collector + TRACK_ACTIVITIES("track_activities", (r) -> Randomly.fromOptions(1, 0)), + // track_activity_query_size + TRACK_COUNTS("track_counts", (r) -> Randomly.fromOptions(1, 0)), + TRACK_IO_TIMING("track_io_timing", (r) -> Randomly.fromOptions(1, 0)), + TRACK_FUNCTIONS("track_functions", (r) -> Randomly.fromOptions("'none'", "'pl'", "'all'")), + // stats_temp_directory + // TODO 19.9.2. Statistics Monitoring + // https://www.postgresql.org/docs/11/runtime-config-autovacuum.html + // all can only be set at server-conf time + // 19.11. Client Connection Defaults + VACUUM_FREEZE_TABLE_AGE("vacuum_freeze_table_age", (r) -> Randomly.fromOptions(0, 5, 10, 100, 500, 2000000000)), + VACUUM_FREEZE_MIN_AGE("vacuum_freeze_min_age", (r) -> Randomly.fromOptions(0, 5, 10, 100, 500, 1000000000)), + VACUUM_MULTIXACT_FREEZE_TABLE_AGE("vacuum_multixact_freeze_table_age", + (r) -> Randomly.fromOptions(0, 5, 10, 100, 500, 2000000000)), + VACUUM_MULTIXACT_FREEZE_MIN_AGE("vacuum_multixact_freeze_min_age", + (r) -> Randomly.fromOptions(0, 5, 10, 100, 500, 1000000000)), + VACUUM_CLEANUP_INDEX_SCALE_FACTOR("vacuum_cleanup_index_scale_factor", + (r) -> Randomly.fromOptions(0.0, 0.0000001, 0.00001, 0.01, 0.1, 1, 10, 100, 100000, 10000000000.0)), + // TODO others + GIN_FUZZY_SEARCH_LIMIT("gin_fuzzy_search_limit", (r) -> r.getInteger(0, 2147483647)), + // 19.13. Version and Platform Compatibility + DEFAULT_WITH_OIDS("default_with_oids", (r) -> Randomly.fromOptions(0, 1)), + SYNCHRONIZED_SEQSCANS("synchronize_seqscans", (r) -> Randomly.fromOptions(0, 1)), + // https://www.postgresql.org/docs/devel/runtime-config-query.html + ENABLE_BITMAPSCAN("enable_bitmapscan", (r) -> Randomly.fromOptions(1, 0)), + ENABLE_GATHERMERGE("enable_gathermerge", (r) -> Randomly.fromOptions(1, 0)), + ENABLE_HASHJOIN("enable_hashjoin", (r) -> Randomly.fromOptions(1, 0)), + ENABLE_INDEXSCAN("enable_indexscan", (r) -> Randomly.fromOptions(1, 0)), + ENABLE_INDEXONLYSCAN("enable_indexonlyscan", (r) -> Randomly.fromOptions(1, 0)), + ENABLE_MATERIAL("enable_material", (r) -> Randomly.fromOptions(1, 0)), + ENABLE_MERGEJOIN("enable_mergejoin", (r) -> Randomly.fromOptions(1, 0)), + ENABLE_NESTLOOP("enable_nestloop", (r) -> Randomly.fromOptions(1, 0)), + ENABLE_PARALLEL_APPEND("enable_parallel_append", (r) -> Randomly.fromOptions(1, 0)), + ENABLE_PARALLEL_HASH("enable_parallel_hash", (r) -> Randomly.fromOptions(1, 0)), + ENABLE_PARTITION_PRUNING("enable_partition_pruning", (r) -> Randomly.fromOptions(1, 0)), + ENABLE_PARTITIONWISE_JOIN("enable_partitionwise_join", (r) -> Randomly.fromOptions(1, 0)), + ENABLE_PARTITIONWISE_AGGREGATE("enable_partitionwise_aggregate", (r) -> Randomly.fromOptions(1, 0)), + ENABLE_SEGSCAN("enable_seqscan", (r) -> Randomly.fromOptions(1, 0)), + ENABLE_SORT("enable_sort", (r) -> Randomly.fromOptions(1, 0)), + ENABLE_TIDSCAN("enable_tidscan", (r) -> Randomly.fromOptions(1, 0)), + // 19.7.2. Planner Cost Constants (complete as of March 2020) + // https://www.postgresql.org/docs/current/runtime-config-query.html#RUNTIME-CONFIG-QUERY-CONSTANTS + SEQ_PAGE_COST("seq_page_cost", (r) -> Randomly.fromOptions(0d, 0.00001, 0.05, 0.1, 1, 10, 10000)), + RANDOM_PAGE_COST("random_page_cost", (r) -> Randomly.fromOptions(0d, 0.00001, 0.05, 0.1, 1, 10, 10000)), + CPU_TUPLE_COST("cpu_tuple_cost", (r) -> Randomly.fromOptions(0d, 0.00001, 0.05, 0.1, 1, 10, 10000)), + CPU_INDEX_TUPLE_COST("cpu_index_tuple_cost", (r) -> Randomly.fromOptions(0d, 0.00001, 0.05, 0.1, 1, 10, 10000)), + CPU_OPERATOR_COST("cpu_operator_cost", (r) -> Randomly.fromOptions(0d, 0.000001, 0.0025, 0.1, 1, 10, 10000)), + PARALLEL_SETUP_COST("parallel_setup_cost", (r) -> r.getLong(0, Long.MAX_VALUE)), + PARALLEL_TUPLE_COST("parallel_tuple_cost", (r) -> r.getLong(0, Long.MAX_VALUE)), + MIN_PARALLEL_TABLE_SCAN_SIZE("min_parallel_table_scan_size", (r) -> r.getInteger(0, 715827882)), + MIN_PARALLEL_INDEX_SCAN_SIZE("min_parallel_index_scan_size", (r) -> r.getInteger(0, 715827882)), + EFFECTIVE_CACHE_SIZE("effective_cache_size", (r) -> r.getInteger(1, 2147483647)), + JIT_ABOVE_COST("jit_above_cost", (r) -> Randomly.fromOptions(0, r.getLong(-1, Long.MAX_VALUE - 1))), + JIT_INLINE_ABOVE_COST("jit_inline_above_cost", (r) -> Randomly.fromOptions(0, r.getLong(-1, Long.MAX_VALUE))), + JIT_OPTIMIZE_ABOVE_COST("jit_optimize_above_cost", + (r) -> Randomly.fromOptions(0, r.getLong(-1, Long.MAX_VALUE))), + // 19.7.3. Genetic Query Optimizer (complete as of March 2020) + // https://www.postgresql.org/docs/current/runtime-config-query.html#RUNTIME-CONFIG-QUERY-GEQO + GEQO("geqo", (r) -> Randomly.fromOptions(1, 0)), + GEQO_THRESHOLD("geqo_threshold", (r) -> r.getInteger(2, 2147483647)), + GEQO_EFFORT("geqo_effort", (r) -> r.getInteger(1, 10)), + GEQO_POO_SIZE("geqo_pool_size", (r) -> r.getInteger(0, 2147483647)), + GEQO_GENERATIONS("geqo_generations", (r) -> r.getInteger(0, 2147483647)), + GEQO_SELECTION_BIAS("geqo_selection_bias", (r) -> Randomly.fromOptions(1.5, 1.8, 2.0)), + GEQO_SEED("geqo_seed", (r) -> Randomly.fromOptions(0, 0.5, 1)), + // 19.7.4. Other Planner Options (complete as of March 2020) + // https://www.postgresql.org/docs/current/runtime-config-query.html#RUNTIME-CONFIG-QUERY-OTHER + DEFAULT_STATISTICS_TARGET("default_statistics_target", (r) -> r.getInteger(1, 10000)), + CONSTRAINT_EXCLUSION("constraint_exclusion", (r) -> Randomly.fromOptions("on", "off", "partition")), + CURSOR_TUPLE_FRACTION("cursor_tuple_fraction", + (r) -> Randomly.fromOptions(0.0, 0.1, 0.000001, 1, 0.5, 0.9999999)), + FROM_COLLAPSE_LIMIT("from_collapse_limit", (r) -> r.getInteger(1, Integer.MAX_VALUE)), + JIT("jit", (r) -> Randomly.fromOptions(1, 0)), + JOIN_COLLAPSE_LIMIT("join_collapse_limit", (r) -> r.getInteger(1, Integer.MAX_VALUE)), + PARALLEL_LEADER_PARTICIPATION("parallel_leader_participation", (r) -> Randomly.fromOptions(1, 0)), + FORCE_PARALLEL_MODE("force_parallel_mode", (r) -> Randomly.fromOptions("off", "on", "regress")); + + private final String optionName; + private final Function op; + + ConfigurationOption(String optionName, Function op) { + this.optionName = optionName; + this.op = op; + } + + public String getOptionName() { + return optionName; + } + } + +} diff --git a/src/sqlancer/yugabyte/ysql/gen/YSQLStatisticsGenerator.java b/src/sqlancer/yugabyte/ysql/gen/YSQLStatisticsGenerator.java new file mode 100644 index 000000000..3b212b00b --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/gen/YSQLStatisticsGenerator.java @@ -0,0 +1,74 @@ +package sqlancer.yugabyte.ysql.gen; + +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.yugabyte.ysql.YSQLGlobalState; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLColumn; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLStatisticsObject; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLTable; + +public final class YSQLStatisticsGenerator { + + private YSQLStatisticsGenerator() { + } + + public static SQLQueryAdapter insert(YSQLGlobalState globalState) { + StringBuilder sb = new StringBuilder(); + sb.append("CREATE STATISTICS "); + if (Randomly.getBoolean()) { + sb.append(" IF NOT EXISTS"); + } + YSQLTable randomTable = globalState.getSchema().getRandomTable(t -> !t.isView()); // TODO materialized view + if (randomTable.getColumns().size() < 2) { + throw new IgnoreMeException(); + } + sb.append(" "); + sb.append(getNewStatisticsName(randomTable)); + if (Randomly.getBoolean()) { + sb.append(" ("); + List statsSubset; + statsSubset = Randomly.nonEmptySubset("ndistinct", "dependencies", "mcv"); + sb.append(String.join(", ", statsSubset)); + sb.append(")"); + } + + List randomColumns = randomTable.getRandomNonEmptyColumnSubset( + globalState.getRandomly().getInteger(2, randomTable.getColumns().size())); + sb.append(" ON "); + sb.append(randomColumns.stream().map(AbstractTableColumn::getName).collect(Collectors.joining(", "))); + sb.append(" FROM "); + sb.append(randomTable.getName()); + return new SQLQueryAdapter(sb.toString(), ExpectedErrors.from("cannot have more than 8 columns in statistics"), + true); + } + + public static SQLQueryAdapter remove(YSQLGlobalState globalState) { + StringBuilder sb = new StringBuilder("DROP STATISTICS "); + YSQLTable randomTable = globalState.getSchema().getRandomTable(); + List statistics = randomTable.getStatistics(); + if (statistics.isEmpty()) { + throw new IgnoreMeException(); + } + sb.append(Randomly.fromList(statistics).getName()); + return new SQLQueryAdapter(sb.toString(), true); + } + + private static String getNewStatisticsName(YSQLTable randomTable) { + List statistics = randomTable.getStatistics(); + int i = 0; + while (true) { + String candidateName = "s" + i; + if (statistics.stream().noneMatch(stat -> stat.getName().contentEquals(candidateName))) { + return candidateName; + } + i++; + } + } + +} diff --git a/src/sqlancer/yugabyte/ysql/gen/YSQLTableGenerator.java b/src/sqlancer/yugabyte/ysql/gen/YSQLTableGenerator.java new file mode 100644 index 000000000..37945c7ac --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/gen/YSQLTableGenerator.java @@ -0,0 +1,248 @@ +package sqlancer.yugabyte.ysql.gen; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.DBMSCommon; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.yugabyte.ysql.YSQLErrors; +import sqlancer.yugabyte.ysql.YSQLGlobalState; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLColumn; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLTable; +import sqlancer.yugabyte.ysql.YSQLVisitor; +import sqlancer.yugabyte.ysql.ast.YSQLExpression; + +public class YSQLTableGenerator { + + protected final ExpectedErrors errors = new ExpectedErrors(); + private final String tableName; + private final StringBuilder sb = new StringBuilder(); + private final List columnsToBeAdded = new ArrayList<>(); + private final YSQLTable table; + private final boolean generateOnlyKnown; + private final YSQLGlobalState globalState; + private boolean columnCanHavePrimaryKey; + private boolean columnHasPrimaryKey; + private boolean isTemporaryTable; + + public YSQLTableGenerator(String tableName, boolean generateOnlyKnown, YSQLGlobalState globalState) { + this.tableName = tableName; + this.generateOnlyKnown = generateOnlyKnown; + this.globalState = globalState; + table = new YSQLTable(tableName, columnsToBeAdded, null, null, null, false, false); + // YB catalog specific messages + errors.add("The catalog snapshot used for this transaction has been invalidated"); + + errors.add("PRIMARY KEY containing column of type"); + errors.add("specified value cannot be cast to type boolean for column"); + errors.add("already exists"); + errors.add("invalid input syntax for"); + errors.add("is not unique"); + errors.add("integer out of range"); + errors.add("division by zero"); + errors.add("cannot create partitioned table as inheritance child"); + errors.add("cannot cast"); + errors.add("ERROR: functions in index expression must be marked IMMUTABLE"); + errors.add("functions in partition key expression must be marked IMMUTABLE"); + errors.add("functions in index predicate must be marked IMMUTABLE"); + errors.add("has no default operator class for access method"); + errors.add("does not exist for access method"); + errors.add("does not accept data type"); + errors.add("but default expression is of type text"); + errors.add("has pseudo-type unknown"); + errors.add("no collation was derived for partition key column"); + errors.add("cannot set colocated true on a non-colocated database"); + errors.add("Cannot split table that does not have primary key"); + errors.add("inherits from generated column but specifies identity"); + errors.add("inherits from generated column but specifies default"); + YSQLErrors.addCommonExpressionErrors(errors); + YSQLErrors.addCommonTableErrors(errors); + } + + public static SQLQueryAdapter generate(String tableName, boolean generateOnlyKnown, YSQLGlobalState globalState) { + return new YSQLTableGenerator(tableName, generateOnlyKnown, globalState).generate(); + } + + private SQLQueryAdapter generate() { + columnCanHavePrimaryKey = true; + sb.append("CREATE"); + if (Randomly.getBooleanWithSmallProbability()) { + sb.append(" "); + isTemporaryTable = true; + sb.append(Randomly.fromOptions("TEMPORARY", "TEMP")); + } + sb.append(" TABLE"); + if (Randomly.getBoolean()) { + sb.append(" IF NOT EXISTS"); + } + sb.append(" "); + sb.append(tableName); + createStandard(); + return new SQLQueryAdapter(sb.toString(), errors, true); + } + + private void createStandard() throws AssertionError { + sb.append("("); + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + if (i != 0) { + sb.append(", "); + } + String name = DBMSCommon.createColumnName(i); + createColumn(name); + } + if (Randomly.getBoolean()) { + errors.add("constraints on temporary tables may reference only temporary tables"); + errors.add("constraints on unlogged tables may reference only permanent or unlogged tables"); + errors.add("constraints on permanent tables may reference only permanent tables"); + errors.add("cannot be implemented"); + errors.add("there is no unique constraint matching given keys for referenced table"); + errors.add("cannot reference partitioned table"); + errors.add("unsupported ON COMMIT and foreign key combination"); + errors.add("ERROR: invalid ON DELETE action for foreign key constraint containing generated column"); + errors.add("exclusion constraints are not supported on partitioned tables"); + errors.add("option is not yet supported for hash partitioned tables"); + YSQLCommon.addTableConstraints(columnHasPrimaryKey, sb, table, globalState, errors); + } + sb.append(")"); + generatePartitionBy(); + YSQLCommon.generateWith(sb, globalState, errors, columnsToBeAdded, isTemporaryTable); + if (Randomly.getBoolean() && isTemporaryTable) { + sb.append(" ON COMMIT "); + // todo ON COMMIT DROP fails and it's known issue + // sb.append(Randomly.fromOptions("PRESERVE ROWS", "DELETE ROWS", "DROP")); + sb.append(Randomly.fromOptions("PRESERVE ROWS", "DELETE ROWS")); + sb.append(" "); + } + } + + private void createColumn(String name) throws AssertionError { + sb.append(name); + sb.append(" "); + YSQLDataType type = YSQLDataType.getRandomType(); + boolean serial = YSQLCommon.appendDataType(type, sb, true, generateOnlyKnown, globalState.getCollates()); + YSQLColumn c = new YSQLColumn(name, type); + c.setTable(table); + columnsToBeAdded.add(c); + sb.append(" "); + if (Randomly.getBoolean()) { + createColumnConstraint(type, serial); + } + } + + private void generatePartitionBy() { + if (Randomly.getBoolean()) { + return; + } + sb.append(" PARTITION BY "); + // TODO "RANGE", + String partitionOption = Randomly.fromOptions("RANGE", "LIST", "HASH"); + sb.append(partitionOption); + sb.append("("); + errors.add("unrecognized parameter"); + errors.add("cannot use constant expression"); + errors.add("unrecognized parameter"); + errors.add("unsupported PRIMARY KEY constraint with partition key definition"); + errors.add("which is part of the partition key."); + errors.add("unsupported UNIQUE constraint with partition key definition"); + errors.add("does not accept data type"); + int n = partitionOption.contentEquals("LIST") ? 1 : Randomly.smallNumber() + 1; + YSQLErrors.addCommonExpressionErrors(errors); + for (int i = 0; i < n; i++) { + if (i != 0) { + sb.append(", "); + } + sb.append("("); + YSQLExpression expr = YSQLExpressionGenerator.generateExpression(globalState, columnsToBeAdded); + sb.append(YSQLVisitor.asString(expr)); + sb.append(")"); + if (Randomly.getBoolean()) { + sb.append(globalState.getRandomOpclass()); + errors.add("does not exist for access method"); + } + } + sb.append(")"); + } + + private void createColumnConstraint(YSQLDataType type, boolean serial) { + List constraintSubset = Randomly.nonEmptySubset(ColumnConstraint.values()); + if (Randomly.getBoolean()) { + // make checks constraints less likely + constraintSubset.remove(ColumnConstraint.CHECK); + } + if (!columnCanHavePrimaryKey || columnHasPrimaryKey) { + constraintSubset.remove(ColumnConstraint.PRIMARY_KEY); + } + if (constraintSubset.contains(ColumnConstraint.GENERATED) + && constraintSubset.contains(ColumnConstraint.DEFAULT)) { + // otherwise: ERROR: both default and identity specified for column + constraintSubset.remove(Randomly.fromOptions(ColumnConstraint.GENERATED, ColumnConstraint.DEFAULT)); + } + if (constraintSubset.contains(ColumnConstraint.GENERATED) && type != YSQLDataType.INT) { + // otherwise: ERROR: identity column type must be smallint, integer, or bigint + constraintSubset.remove(ColumnConstraint.GENERATED); + } + if (serial) { + constraintSubset.remove(ColumnConstraint.GENERATED); + constraintSubset.remove(ColumnConstraint.DEFAULT); + constraintSubset.remove(ColumnConstraint.NULL_OR_NOT_NULL); + + } + for (ColumnConstraint c : constraintSubset) { + sb.append(" "); + switch (c) { + case NULL_OR_NOT_NULL: + sb.append(Randomly.fromOptions("NOT NULL", "NULL")); + errors.add("conflicting NULL/NOT NULL declarations"); + break; + case UNIQUE: + sb.append("UNIQUE"); + break; + case PRIMARY_KEY: + sb.append("PRIMARY KEY"); + columnHasPrimaryKey = true; + break; + case DEFAULT: + sb.append("DEFAULT"); + sb.append(" ("); + sb.append(YSQLVisitor.asString(YSQLExpressionGenerator.generateExpression(globalState, type))); + sb.append(")"); + errors.add("out of range"); + errors.add("is a generated column"); + break; + case CHECK: + sb.append("CHECK ("); + sb.append(YSQLVisitor.asString(YSQLExpressionGenerator.generateExpression(globalState, columnsToBeAdded, + YSQLDataType.BOOLEAN))); + sb.append(")"); + errors.add("out of range"); + break; + case GENERATED: + sb.append("GENERATED "); + if (Randomly.getBoolean()) { + sb.append(" ALWAYS AS ("); + sb.append(YSQLVisitor + .asString(YSQLExpressionGenerator.generateExpression(globalState, columnsToBeAdded, type))); + sb.append(") STORED"); + errors.add("A generated column cannot reference another generated column."); + errors.add("cannot use generated column in partition key"); + errors.add("generation expression is not immutable"); + errors.add("cannot use column reference in DEFAULT expression"); + } else { + sb.append(Randomly.fromOptions("ALWAYS", "BY DEFAULT")); + sb.append(" AS IDENTITY"); + } + break; + default: + throw new AssertionError(sb); + } + } + } + + private enum ColumnConstraint { + NULL_OR_NOT_NULL, UNIQUE, PRIMARY_KEY, DEFAULT, CHECK, GENERATED + } + +} diff --git a/src/sqlancer/yugabyte/ysql/gen/YSQLTableGroupGenerator.java b/src/sqlancer/yugabyte/ysql/gen/YSQLTableGroupGenerator.java new file mode 100644 index 000000000..8bc292348 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/gen/YSQLTableGroupGenerator.java @@ -0,0 +1,25 @@ +package sqlancer.yugabyte.ysql.gen; + +import java.util.concurrent.atomic.AtomicLong; + +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.yugabyte.ysql.YSQLGlobalState; + +public final class YSQLTableGroupGenerator { + + // TODO rework + public static final AtomicLong UNIQUE_TABLEGROUP_COUNTER = new AtomicLong(1); + + private YSQLTableGroupGenerator() { + } + + public static SQLQueryAdapter create(YSQLGlobalState globalState) { + ExpectedErrors errors = new ExpectedErrors(); + StringBuilder sb = new StringBuilder("CREATE TABLEGROUP "); + String tableGroupName = "tg" + UNIQUE_TABLEGROUP_COUNTER.incrementAndGet(); + sb.append(tableGroupName); + errors.add("cannot use tablegroups in a colocated database"); + return new SQLQueryAdapter(sb.toString(), errors, true); + } +} diff --git a/src/sqlancer/yugabyte/ysql/gen/YSQLTransactionGenerator.java b/src/sqlancer/yugabyte/ysql/gen/YSQLTransactionGenerator.java new file mode 100644 index 000000000..67780c016 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/gen/YSQLTransactionGenerator.java @@ -0,0 +1,27 @@ +package sqlancer.yugabyte.ysql.gen; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; + +public final class YSQLTransactionGenerator { + + private YSQLTransactionGenerator() { + } + + public static SQLQueryAdapter executeBegin() { + ExpectedErrors errors = new ExpectedErrors(); + StringBuilder sb = new StringBuilder("BEGIN"); + if (Randomly.getBoolean()) { + errors.add("SET TRANSACTION ISOLATION LEVEL must be called before any query"); + sb.append(" ISOLATION LEVEL "); + sb.append(Randomly.fromOptions("SERIALIZABLE", "REPEATABLE READ", "READ COMMITTED")); + // if (Randomly.getBoolean()) { + // sb.append(" "); + // sb.append(Randomly.fromOptions("READ WRITE", "READ ONLY")); + // } + } + return new SQLQueryAdapter(sb.toString(), errors, true); + } + +} diff --git a/src/sqlancer/yugabyte/ysql/gen/YSQLTruncateGenerator.java b/src/sqlancer/yugabyte/ysql/gen/YSQLTruncateGenerator.java new file mode 100644 index 000000000..449654d99 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/gen/YSQLTruncateGenerator.java @@ -0,0 +1,29 @@ +package sqlancer.yugabyte.ysql.gen; + +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.schema.AbstractTable; +import sqlancer.yugabyte.ysql.YSQLGlobalState; + +public final class YSQLTruncateGenerator { + + private YSQLTruncateGenerator() { + } + + public static SQLQueryAdapter create(YSQLGlobalState globalState) { + StringBuilder sb = new StringBuilder(); + sb.append("TRUNCATE"); + if (Randomly.getBoolean()) { + sb.append(" TABLE"); + } + sb.append(" "); + sb.append(globalState.getSchema().getDatabaseTablesRandomSubsetNotEmpty().stream().map(AbstractTable::getName) + .collect(Collectors.joining(", "))); + return new SQLQueryAdapter(sb.toString(), ExpectedErrors + .from("cannot truncate a table referenced in a foreign key constraint", "is not a table")); + } + +} diff --git a/src/sqlancer/yugabyte/ysql/gen/YSQLUpdateGenerator.java b/src/sqlancer/yugabyte/ysql/gen/YSQLUpdateGenerator.java new file mode 100644 index 000000000..bc7b00d79 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/gen/YSQLUpdateGenerator.java @@ -0,0 +1,82 @@ +package sqlancer.yugabyte.ysql.gen; + +import java.util.Arrays; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.gen.AbstractUpdateGenerator; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.yugabyte.ysql.YSQLErrors; +import sqlancer.yugabyte.ysql.YSQLGlobalState; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLColumn; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLTable; +import sqlancer.yugabyte.ysql.YSQLVisitor; +import sqlancer.yugabyte.ysql.ast.YSQLExpression; + +public final class YSQLUpdateGenerator extends AbstractUpdateGenerator { + + private final YSQLGlobalState globalState; + private YSQLTable randomTable; + + private YSQLUpdateGenerator(YSQLGlobalState globalState) { + this.globalState = globalState; + errors.addAll(Arrays.asList("conflicting key value violates exclusion constraint", + "reached maximum value of sequence", "violates foreign key constraint", "violates not-null constraint", + "violates unique constraint", "out of range", "cannot cast", "must be type boolean", "is not unique", + " bit string too long", "can only be updated to DEFAULT", "division by zero", + "You might need to add explicit type casts.", "invalid regular expression", + "View columns that are not columns of their base relation are not updatable")); + } + + public static SQLQueryAdapter create(YSQLGlobalState globalState) { + return new YSQLUpdateGenerator(globalState).generate(); + } + + private SQLQueryAdapter generate() { + randomTable = globalState.getSchema().getRandomTable(YSQLTable::isInsertable); + List columns = randomTable.getRandomNonEmptyColumnSubset(); + sb.append("UPDATE "); + sb.append(randomTable.getName()); + sb.append(" SET "); + errors.add("multiple assignments to same column"); // view whose columns refer to a column in the referenced + // table multiple times + errors.add("new row violates check option for view"); + YSQLErrors.addCommonInsertUpdateErrors(errors); + + updateColumns(columns); + errors.add("invalid input syntax for "); + errors.add("operator does not exist: text = boolean"); + errors.add("violates check constraint"); + errors.add("could not determine which collation to use for string comparison"); + errors.add("but expression is of type"); + YSQLErrors.addCommonExpressionErrors(errors); + if (!Randomly.getBooleanWithSmallProbability()) { + sb.append(" WHERE "); + YSQLExpression where = YSQLExpressionGenerator.generateExpression(globalState, randomTable.getColumns(), + YSQLDataType.BOOLEAN); + sb.append(YSQLVisitor.asString(where)); + } + + return new SQLQueryAdapter(sb.toString(), errors, true); + } + + @Override + protected void updateValue(YSQLColumn column) { + if (!Randomly.getBoolean()) { + YSQLExpression constant = YSQLExpressionGenerator.generateConstant(globalState.getRandomly(), + column.getType()); + sb.append(YSQLVisitor.asString(constant)); + } else if (Randomly.getBoolean()) { + sb.append("DEFAULT"); + } else { + sb.append("("); + YSQLExpression expr = YSQLExpressionGenerator.generateExpression(globalState, randomTable.getColumns(), + column.getType()); + // caused by casts + sb.append(YSQLVisitor.asString(expr)); + sb.append(")"); + } + } + +} diff --git a/src/sqlancer/yugabyte/ysql/gen/YSQLVacuumGenerator.java b/src/sqlancer/yugabyte/ysql/gen/YSQLVacuumGenerator.java new file mode 100644 index 000000000..9efe027b1 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/gen/YSQLVacuumGenerator.java @@ -0,0 +1,16 @@ +package sqlancer.yugabyte.ysql.gen; + +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.yugabyte.ysql.YSQLGlobalState; + +public final class YSQLVacuumGenerator { + + private YSQLVacuumGenerator() { + } + + public static SQLQueryAdapter create(YSQLGlobalState globalState) { + return new SQLQueryAdapter("VACUUM", ExpectedErrors.from("VACUUM cannot run inside a transaction block")); + } + +} diff --git a/src/sqlancer/yugabyte/ysql/gen/YSQLViewGenerator.java b/src/sqlancer/yugabyte/ysql/gen/YSQLViewGenerator.java new file mode 100644 index 000000000..3005d49f8 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/gen/YSQLViewGenerator.java @@ -0,0 +1,52 @@ +package sqlancer.yugabyte.ysql.gen; + +import sqlancer.Randomly; +import sqlancer.common.DBMSCommon; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.yugabyte.ysql.YSQLErrors; +import sqlancer.yugabyte.ysql.YSQLGlobalState; +import sqlancer.yugabyte.ysql.YSQLVisitor; +import sqlancer.yugabyte.ysql.ast.YSQLSelect; + +public final class YSQLViewGenerator { + + private YSQLViewGenerator() { + } + + public static SQLQueryAdapter create(YSQLGlobalState globalState) { + ExpectedErrors errors = new ExpectedErrors(); + StringBuilder sb = new StringBuilder("CREATE"); + if (Randomly.getBoolean()) { + sb.append(" MATERIALIZED"); + } else { + if (Randomly.getBoolean()) { + sb.append(" OR REPLACE"); + } + if (Randomly.getBoolean()) { + sb.append(Randomly.fromOptions(" TEMP", " TEMPORARY")); + } + } + sb.append(" VIEW "); + String name = globalState.getSchema().getFreeViewName(); + sb.append(name); + sb.append("("); + int nrColumns = Randomly.smallNumber() + 1; + for (int i = 0; i < nrColumns; i++) { + if (i != 0) { + sb.append(", "); + } + sb.append(DBMSCommon.createColumnName(i)); + } + sb.append(")"); + sb.append(" AS ("); + YSQLSelect select = YSQLRandomQueryGenerator.createRandomQuery(nrColumns, globalState); + sb.append(YSQLVisitor.asString(select)); + sb.append(")"); + YSQLErrors.addGroupingErrors(errors); + YSQLErrors.addViewErrors(errors); + YSQLErrors.addCommonExpressionErrors(errors); + return new SQLQueryAdapter(sb.toString(), errors, true); + } + +} diff --git a/src/sqlancer/yugabyte/ysql/oracle/YSQLCatalog.java b/src/sqlancer/yugabyte/ysql/oracle/YSQLCatalog.java new file mode 100644 index 000000000..e5ae3fe1b --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/oracle/YSQLCatalog.java @@ -0,0 +1,96 @@ +package sqlancer.yugabyte.ysql.oracle; + +import static sqlancer.yugabyte.ysql.YSQLProvider.DDL_LOCK; + +import java.util.Arrays; +import java.util.List; + +import sqlancer.IgnoreMeException; +import sqlancer.Main; +import sqlancer.MainOptions; +import sqlancer.SQLConnection; +import sqlancer.common.DBMSCommon; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.yugabyte.ysql.YSQLErrors; +import sqlancer.yugabyte.ysql.YSQLGlobalState; +import sqlancer.yugabyte.ysql.YSQLProvider; +import sqlancer.yugabyte.ysql.gen.YSQLTableGenerator; + +public class YSQLCatalog implements TestOracle { + protected final YSQLGlobalState state; + + protected final ExpectedErrors errors = new ExpectedErrors(); + protected final Main.StateLogger logger; + protected final MainOptions options; + protected final SQLConnection con; + + private final List dmlActions = Arrays.asList(YSQLProvider.Action.INSERT, + YSQLProvider.Action.UPDATE, YSQLProvider.Action.DELETE); + private final List catalogActions = Arrays.asList(YSQLProvider.Action.CREATE_VIEW, + YSQLProvider.Action.CREATE_SEQUENCE, YSQLProvider.Action.ALTER_TABLE, YSQLProvider.Action.SET_CONSTRAINTS, + YSQLProvider.Action.DISCARD, YSQLProvider.Action.DROP_INDEX, YSQLProvider.Action.COMMENT_ON, + YSQLProvider.Action.RESET_ROLE, YSQLProvider.Action.RESET); + private final List diskActions = Arrays.asList(YSQLProvider.Action.TRUNCATE, + YSQLProvider.Action.VACUUM); + + public YSQLCatalog(YSQLGlobalState globalState) { + this.state = globalState; + this.con = state.getConnection(); + this.logger = state.getLogger(); + this.options = state.getOptions(); + YSQLErrors.addCommonExpressionErrors(errors); + YSQLErrors.addCommonFetchErrors(errors); + } + + private YSQLProvider.Action getRandomAction(List actions) { + return actions.get(state.getRandomly().getInteger(0, actions.size())); + } + + protected void createTables(YSQLGlobalState globalState, int numTables) throws Exception { + synchronized (DDL_LOCK) { + while (globalState.getSchema().getDatabaseTables().size() < numTables) { + // TODO concurrent DDLs may produce a lot of noise in test logs so its disabled right now + // added timeout to avoid possible catalog collisions + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + throw new AssertionError(); + } + + try { + String tableName = DBMSCommon.createTableName(globalState.getSchema().getDatabaseTables().size()); + SQLQueryAdapter createTable = YSQLTableGenerator.generate(tableName, true, globalState); + globalState.executeStatement(createTable); + globalState.getManager().incrementSelectQueryCount(); + globalState.executeStatement(new SQLQueryAdapter("COMMIT", true)); + } catch (IgnoreMeException e) { + // do nothing + } + } + } + } + + @Override + public void check() throws Exception { + // create table or evaluate catalog test + int seed = state.getRandomly().getInteger(1, 100); + if (seed > 95) { + createTables(state, 1); + } else { + YSQLProvider.Action randomAction; + + if (seed > 40) { + randomAction = getRandomAction(dmlActions); + } else if (seed > 10) { + randomAction = getRandomAction(catalogActions); + } else { + randomAction = getRandomAction(diskActions); + } + + state.executeStatement(randomAction.getQuery(state)); + } + state.getManager().incrementSelectQueryCount(); + } +} diff --git a/src/sqlancer/yugabyte/ysql/oracle/YSQLFuzzer.java b/src/sqlancer/yugabyte/ysql/oracle/YSQLFuzzer.java new file mode 100644 index 000000000..8d0cea997 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/oracle/YSQLFuzzer.java @@ -0,0 +1,82 @@ +package sqlancer.yugabyte.ysql.oracle; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.yugabyte.ysql.YSQLErrors; +import sqlancer.yugabyte.ysql.YSQLGlobalState; +import sqlancer.yugabyte.ysql.YSQLProvider; +import sqlancer.yugabyte.ysql.YSQLVisitor; +import sqlancer.yugabyte.ysql.gen.YSQLRandomQueryGenerator; + +public class YSQLFuzzer implements TestOracle { + private final YSQLGlobalState globalState; + private final List testQueries; + private final ExpectedErrors errors = new ExpectedErrors(); + + public YSQLFuzzer(YSQLGlobalState globalState) { + this.globalState = globalState; + + YSQLErrors.addCommonExpressionErrors(errors); + YSQLErrors.addCommonFetchErrors(errors); + YSQLErrors.addGroupingErrors(errors); + YSQLErrors.addViewErrors(errors); + + // remove timeout error from scope + errors.add("canceling statement due to statement timeout"); + + // exclude nemesis exceptions + errors.add("terminating connection due to administrator command"); + errors.add("Java heap space"); + errors.add("Connection refused"); + errors.add("Connection to"); + + testQueries = new ArrayList<>(); + + testQueries.add(new SelectQuery()); + testQueries.add(new ActionQuery(YSQLProvider.Action.UPDATE)); + testQueries.add(new ActionQuery(YSQLProvider.Action.DELETE)); + testQueries.add(new ActionQuery(YSQLProvider.Action.INSERT)); + } + + @Override + public void check() throws Exception { + Query s = testQueries.get(globalState.getRandomly().getInteger(0, testQueries.size())); + globalState.executeStatement(s.getQuery(globalState, errors)); + globalState.getManager().incrementSelectQueryCount(); + } + + private static class Query { + public SQLQueryAdapter getQuery(YSQLGlobalState state, ExpectedErrors errors) throws Exception { + throw new IllegalAccessException("Should be implemented"); + }; + } + + private static class ActionQuery extends Query { + private final YSQLProvider.Action action; + + ActionQuery(YSQLProvider.Action action) { + this.action = action; + } + + @Override + public SQLQueryAdapter getQuery(YSQLGlobalState state, ExpectedErrors errors) throws Exception { + return action.getQuery(state); + } + } + + private static class SelectQuery extends Query { + + @Override + public SQLQueryAdapter getQuery(YSQLGlobalState state, ExpectedErrors errors) throws Exception { + return new SQLQueryAdapter( + YSQLVisitor.asString(YSQLRandomQueryGenerator.createRandomQuery(Randomly.smallNumber() + 1, state)) + + ";", + errors); + } + } +} diff --git a/src/sqlancer/yugabyte/ysql/oracle/YSQLPivotedQuerySynthesisOracle.java b/src/sqlancer/yugabyte/ysql/oracle/YSQLPivotedQuerySynthesisOracle.java new file mode 100644 index 000000000..8f54aa698 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/oracle/YSQLPivotedQuerySynthesisOracle.java @@ -0,0 +1,147 @@ +package sqlancer.yugabyte.ysql.oracle; + +import java.sql.SQLException; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.common.oracle.PivotedQuerySynthesisBase; +import sqlancer.common.query.Query; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.yugabyte.ysql.YSQLErrors; +import sqlancer.yugabyte.ysql.YSQLGlobalState; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLColumn; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLRowValue; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLTables; +import sqlancer.yugabyte.ysql.YSQLVisitor; +import sqlancer.yugabyte.ysql.ast.YSQLColumnValue; +import sqlancer.yugabyte.ysql.ast.YSQLConstant; +import sqlancer.yugabyte.ysql.ast.YSQLExpression; +import sqlancer.yugabyte.ysql.ast.YSQLPostfixOperation; +import sqlancer.yugabyte.ysql.ast.YSQLSelect; +import sqlancer.yugabyte.ysql.gen.YSQLExpressionGenerator; + +public class YSQLPivotedQuerySynthesisOracle + extends PivotedQuerySynthesisBase { + + private List fetchColumns; + + public YSQLPivotedQuerySynthesisOracle(YSQLGlobalState globalState) throws SQLException { + super(globalState); + YSQLErrors.addCommonExpressionErrors(errors); + YSQLErrors.addCommonFetchErrors(errors); + } + + /* + * Prevent name collisions by aliasing the column. + */ + private YSQLColumn getFetchValueAliasedColumn(YSQLColumn c) { + YSQLColumn aliasedColumn = new YSQLColumn(c.getName() + " AS " + c.getTable().getName() + c.getName(), + c.getType()); + aliasedColumn.setTable(c.getTable()); + return aliasedColumn; + } + + private List generateGroupByClause(List columns, YSQLRowValue rw) { + if (Randomly.getBoolean()) { + return columns.stream().map(c -> YSQLColumnValue.create(c, rw.getValues().get(c))) + .collect(Collectors.toList()); + } else { + return Collections.emptyList(); + } + } + + private YSQLConstant generateLimit() { + if (Randomly.getBoolean()) { + return YSQLConstant.createIntConstant(Integer.MAX_VALUE); + } else { + return null; + } + } + + private YSQLExpression generateOffset() { + if (Randomly.getBoolean()) { + return YSQLConstant.createIntConstant(0); + } else { + return null; + } + } + + private YSQLExpression generateRectifiedExpression(List columns, YSQLRowValue rw) { + YSQLExpression expr = new YSQLExpressionGenerator(globalState).setColumns(columns).setRowValue(rw) + .generateExpressionWithExpectedResult(YSQLDataType.BOOLEAN); + YSQLExpression result; + if (expr.getExpectedValue().isNull()) { + result = YSQLPostfixOperation.create(expr, YSQLPostfixOperation.PostfixOperator.IS_NULL); + } else { + result = YSQLPostfixOperation.create(expr, expr.getExpectedValue().cast(YSQLDataType.BOOLEAN).asBoolean() + ? YSQLPostfixOperation.PostfixOperator.IS_TRUE : YSQLPostfixOperation.PostfixOperator.IS_FALSE); + } + rectifiedPredicates.add(result); + return result; + } + + @Override + protected Query getContainmentCheckQuery(Query query) throws SQLException { + StringBuilder sb = new StringBuilder(); + sb.append("SELECT * FROM ("); // ANOTHER SELECT TO USE ORDER BY without restrictions + sb.append(query.getUnterminatedQueryString()); + sb.append(") as result WHERE "); + int i = 0; + for (YSQLColumn c : fetchColumns) { + if (i++ != 0) { + sb.append(" AND "); + } + sb.append("result."); + sb.append(c.getTable().getName()); + sb.append(c.getName()); + if (pivotRow.getValues().get(c).isNull()) { + sb.append(" IS NULL"); + } else { + sb.append(" = "); + sb.append(pivotRow.getValues().get(c).getTextRepresentation()); + } + } + String resultingQueryString = sb.toString(); + return new SQLQueryAdapter(resultingQueryString, errors); + } + + @Override + public SQLQueryAdapter getRectifiedQuery() throws SQLException { + YSQLTables randomFromTables = globalState.getSchema().getRandomTableNonEmptyTables(); + + YSQLSelect selectStatement = new YSQLSelect(); + selectStatement.setSelectType(Randomly.fromOptions(YSQLSelect.SelectType.values())); + List columns = randomFromTables.getColumns(); + pivotRow = randomFromTables.getRandomRowValue(globalState.getConnection()); + + fetchColumns = columns; + selectStatement.setFromList(randomFromTables.getTables().stream() + .map(t -> new YSQLSelect.YSQLFromTable(t, false)).collect(Collectors.toList())); + selectStatement.setFetchColumns(fetchColumns.stream() + .map(c -> new YSQLColumnValue(getFetchValueAliasedColumn(c), pivotRow.getValues().get(c))) + .collect(Collectors.toList())); + YSQLExpression whereClause = generateRectifiedExpression(columns, pivotRow); + selectStatement.setWhereClause(whereClause); + List groupByClause = generateGroupByClause(columns, pivotRow); + selectStatement.setGroupByExpressions(groupByClause); + YSQLExpression limitClause = generateLimit(); + selectStatement.setLimitClause(limitClause); + if (limitClause != null) { + YSQLExpression offsetClause = generateOffset(); + selectStatement.setOffsetClause(offsetClause); + } + List orderBy = new YSQLExpressionGenerator(globalState).setColumns(columns).generateOrderBys(); + selectStatement.setOrderByClauses(orderBy); + return new SQLQueryAdapter(YSQLVisitor.asString(selectStatement)); + } + + @Override + protected String getExpectedValues(YSQLExpression expr) { + return YSQLVisitor.asExpectedValues(expr); + } + +} diff --git a/src/sqlancer/yugabyte/ysql/oracle/tlp/YSQLTLPAggregateOracle.java b/src/sqlancer/yugabyte/ysql/oracle/tlp/YSQLTLPAggregateOracle.java new file mode 100644 index 000000000..3fdaf3f9a --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/oracle/tlp/YSQLTLPAggregateOracle.java @@ -0,0 +1,194 @@ +package sqlancer.yugabyte.ysql.oracle.tlp; + +import java.io.IOException; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.postgresql.util.PSQLException; + +import sqlancer.ComparatorHelper; +import sqlancer.IgnoreMeException; +import sqlancer.Randomly; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.common.query.SQLancerResultSet; +import sqlancer.yugabyte.ysql.YSQLErrors; +import sqlancer.yugabyte.ysql.YSQLGlobalState; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; +import sqlancer.yugabyte.ysql.YSQLVisitor; +import sqlancer.yugabyte.ysql.ast.YSQLAggregate; +import sqlancer.yugabyte.ysql.ast.YSQLAggregate.YSQLAggregateFunction; +import sqlancer.yugabyte.ysql.ast.YSQLAlias; +import sqlancer.yugabyte.ysql.ast.YSQLExpression; +import sqlancer.yugabyte.ysql.ast.YSQLJoin; +import sqlancer.yugabyte.ysql.ast.YSQLPostfixOperation; +import sqlancer.yugabyte.ysql.ast.YSQLPostfixOperation.PostfixOperator; +import sqlancer.yugabyte.ysql.ast.YSQLPrefixOperation; +import sqlancer.yugabyte.ysql.ast.YSQLPrefixOperation.PrefixOperator; +import sqlancer.yugabyte.ysql.ast.YSQLSelect; + +public class YSQLTLPAggregateOracle extends YSQLTLPBase implements TestOracle { + + private String firstResult; + private String secondResult; + private String originalQuery; + private String metamorphicQuery; + + public YSQLTLPAggregateOracle(YSQLGlobalState state) { + super(state); + YSQLErrors.addGroupingErrors(errors); + } + + @Override + public void check() throws SQLException { + super.check(); + aggregateCheck(); + } + + protected void aggregateCheck() throws SQLException { + YSQLAggregateFunction aggregateFunction = Randomly.fromOptions(YSQLAggregateFunction.MAX, + YSQLAggregateFunction.MIN, YSQLAggregateFunction.SUM, YSQLAggregateFunction.BIT_AND, + YSQLAggregateFunction.BIT_OR, YSQLAggregateFunction.BOOL_AND, YSQLAggregateFunction.BOOL_OR, + YSQLAggregateFunction.COUNT); + YSQLAggregate aggregate = gen.generateArgsForAggregate(aggregateFunction.getRandomReturnType(), + aggregateFunction); + List fetchColumns = new ArrayList<>(); + fetchColumns.add(aggregate); + while (Randomly.getBooleanWithRatherLowProbability()) { + fetchColumns.add(gen.generateAggregate()); + } + select.setFetchColumns(Arrays.asList(aggregate)); + if (Randomly.getBooleanWithRatherLowProbability()) { + select.setOrderByClauses(gen.generateOrderBys()); + } + originalQuery = YSQLVisitor.asString(select); + firstResult = getAggregateResult(originalQuery); + metamorphicQuery = createMetamorphicUnionQuery(select, aggregate, select.getFromList()); + secondResult = getAggregateResult(metamorphicQuery); + + String queryFormatString = "-- %s;\n-- result: %s"; + String firstQueryString = String.format(queryFormatString, originalQuery, firstResult); + String secondQueryString = String.format(queryFormatString, metamorphicQuery, secondResult); + state.getState().getLocalState().log(String.format("%s\n%s", firstQueryString, secondQueryString)); + if (firstResult == null && secondResult != null || firstResult != null && secondResult == null + || firstResult != null && !firstResult.contentEquals(secondResult) + && !ComparatorHelper.isEqualDouble(firstResult, secondResult)) { + if (secondResult != null && secondResult.contains("Inf")) { + throw new IgnoreMeException(); // FIXME: average computation + } + String assertionMessage = String.format("the results mismatch!\n%s\n%s", firstQueryString, + secondQueryString); + throw new AssertionError(assertionMessage); + } + } + + private String createMetamorphicUnionQuery(YSQLSelect select, YSQLAggregate aggregate, List from) { + String metamorphicQuery; + YSQLExpression whereClause = gen.generateExpression(YSQLDataType.BOOLEAN); + YSQLExpression negatedClause = new YSQLPrefixOperation(whereClause, PrefixOperator.NOT); + YSQLExpression notNullClause = new YSQLPostfixOperation(whereClause, PostfixOperator.IS_NULL); + List mappedAggregate = mapped(aggregate); + YSQLSelect leftSelect = getSelect(mappedAggregate, from, whereClause, select.getJoinClauses()); + YSQLSelect middleSelect = getSelect(mappedAggregate, from, negatedClause, select.getJoinClauses()); + YSQLSelect rightSelect = getSelect(mappedAggregate, from, notNullClause, select.getJoinClauses()); + metamorphicQuery = "SELECT " + getOuterAggregateFunction(aggregate) + " FROM ("; + metamorphicQuery += YSQLVisitor.asString(leftSelect) + " UNION ALL " + YSQLVisitor.asString(middleSelect) + + " UNION ALL " + YSQLVisitor.asString(rightSelect); + metamorphicQuery += ") as asdf"; + return metamorphicQuery; + } + + private String getAggregateResult(String queryString) throws SQLException { + // log TLP Aggregate SELECT queries on the current log file + if (state.getOptions().logEachSelect()) { + // TODO: refactor me + state.getLogger().writeCurrent(queryString); + try { + state.getLogger().getCurrentFileWriter().flush(); + } catch (IOException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + } + String resultString; + SQLQueryAdapter q = new SQLQueryAdapter(queryString, errors); + try (SQLancerResultSet result = q.executeAndGet(state)) { + if (result == null) { + throw new IgnoreMeException(); + } + if (!result.next()) { + resultString = null; + } else { + resultString = result.getString(1); + } + } catch (PSQLException e) { + throw new AssertionError(queryString, e); + } + return resultString; + } + + private List mapped(YSQLAggregate aggregate) { + switch (aggregate.getFunction()) { + case SUM: + case COUNT: + case BIT_AND: + case BIT_OR: + case BOOL_AND: + case BOOL_OR: + case MAX: + case MIN: + return aliasArgs(Arrays.asList(aggregate)); + // case AVG: + //// List arg = Arrays.asList(new + // YSQLCast(aggregate.getExpr().get(0), + // YSQLDataType.DECIMAL.get())); + // YSQLAggregate sum = new YSQLAggregate(YSQLAggregateFunction.SUM, + // aggregate.getExpr()); + // YSQLCast count = new YSQLCast( + // new YSQLAggregate(YSQLAggregateFunction.COUNT, aggregate.getExpr()), + // YSQLDataType.DECIMAL.get()); + //// YSQLBinaryArithmeticOperation avg = new + // YSQLBinaryArithmeticOperation(sum, count, + // YSQLBinaryArithmeticOperator.DIV); + // return aliasArgs(Arrays.asList(sum, count)); + default: + throw new AssertionError(aggregate.getFunction()); + } + } + + private List aliasArgs(List originalAggregateArgs) { + List args = new ArrayList<>(); + int i = 0; + for (YSQLExpression expr : originalAggregateArgs) { + args.add(new YSQLAlias(expr, "agg" + i++)); + } + return args; + } + + private String getOuterAggregateFunction(YSQLAggregate aggregate) { + switch (aggregate.getFunction()) { + // case AVG: + // return "SUM(agg0::DECIMAL)/SUM(agg1)::DECIMAL"; + case COUNT: + return YSQLAggregateFunction.SUM + "(agg0)"; + default: + return aggregate.getFunction().toString() + "(agg0)"; + } + } + + private YSQLSelect getSelect(List aggregates, List from, YSQLExpression whereClause, + List joinList) { + YSQLSelect leftSelect = new YSQLSelect(); + leftSelect.setFetchColumns(aggregates); + leftSelect.setFromList(from); + leftSelect.setWhereClause(whereClause); + leftSelect.setJoinClauses(joinList); + if (Randomly.getBooleanWithSmallProbability()) { + leftSelect.setGroupByExpressions(gen.generateExpressions(Randomly.smallNumber() + 1)); + } + return leftSelect; + } + +} diff --git a/src/sqlancer/yugabyte/ysql/oracle/tlp/YSQLTLPBase.java b/src/sqlancer/yugabyte/ysql/oracle/tlp/YSQLTLPBase.java new file mode 100644 index 000000000..57fe246e3 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/oracle/tlp/YSQLTLPBase.java @@ -0,0 +1,106 @@ +package sqlancer.yugabyte.ysql.oracle.tlp; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.common.gen.ExpressionGenerator; +import sqlancer.common.oracle.TernaryLogicPartitioningOracleBase; +import sqlancer.common.oracle.TestOracle; +import sqlancer.yugabyte.ysql.YSQLErrors; +import sqlancer.yugabyte.ysql.YSQLGlobalState; +import sqlancer.yugabyte.ysql.YSQLSchema; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLColumn; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLTable; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLTables; +import sqlancer.yugabyte.ysql.ast.YSQLColumnValue; +import sqlancer.yugabyte.ysql.ast.YSQLExpression; +import sqlancer.yugabyte.ysql.ast.YSQLJoin; +import sqlancer.yugabyte.ysql.ast.YSQLSelect; +import sqlancer.yugabyte.ysql.gen.YSQLExpressionGenerator; + +public class YSQLTLPBase extends TernaryLogicPartitioningOracleBase + implements TestOracle { + + protected YSQLSchema s; + protected YSQLTables targetTables; + protected YSQLExpressionGenerator gen; + protected YSQLSelect select; + + public YSQLTLPBase(YSQLGlobalState state) { + super(state); + YSQLErrors.addCommonExpressionErrors(errors); + YSQLErrors.addCommonFetchErrors(errors); + } + + @Override + public void check() throws SQLException { + s = state.getSchema(); + targetTables = s.getRandomTableNonEmptyTables(); + List tables = targetTables.getTables(); + List joins = getJoinStatements(state, targetTables.getColumns(), tables); + generateSelectBase(tables, joins); + } + + public static List getJoinStatements(YSQLGlobalState globalState, List columns, + List tables) { + List joinStatements = new ArrayList<>(); + YSQLExpressionGenerator gen = new YSQLExpressionGenerator(globalState).setColumns(columns); + for (int i = 1; i < tables.size(); i++) { + YSQLExpression joinClause = gen.generateExpression(YSQLDataType.BOOLEAN); + YSQLTable table = Randomly.fromList(tables); + tables.remove(table); + YSQLJoin.YSQLJoinType options = YSQLJoin.YSQLJoinType.getRandom(); + YSQLJoin j = new YSQLJoin(new YSQLSelect.YSQLFromTable(table, Randomly.getBoolean()), joinClause, options); + joinStatements.add(j); + } + // JOIN subqueries + for (int i = 0; i < Randomly.smallNumber(); i++) { + YSQLTables subqueryTables = globalState.getSchema().getRandomTableNonEmptyTables(); + YSQLSelect.YSQLSubquery subquery = YSQLExpressionGenerator.createSubquery(globalState, + String.format("sub%d", i), subqueryTables); + YSQLExpression joinClause = gen.generateExpression(YSQLDataType.BOOLEAN); + YSQLJoin.YSQLJoinType options = YSQLJoin.YSQLJoinType.getRandom(); + YSQLJoin j = new YSQLJoin(subquery, joinClause, options); + joinStatements.add(j); + } + return joinStatements; + } + + protected void generateSelectBase(List tables, List joins) { + List tableList = tables.stream() + .map(t -> new YSQLSelect.YSQLFromTable(t, Randomly.getBoolean())).collect(Collectors.toList()); + gen = new YSQLExpressionGenerator(state).setColumns(targetTables.getColumns()); + initializeTernaryPredicateVariants(); + select = new YSQLSelect(); + select.setFetchColumns(generateFetchColumns()); + select.setFromList(tableList); + select.setWhereClause(null); + select.setJoinClauses(joins); + if (Randomly.getBoolean()) { + select.setForClause(YSQLSelect.ForClause.getRandom()); + } + } + + List generateFetchColumns() { + if (Randomly.getBooleanWithRatherLowProbability()) { + return Arrays.asList(new YSQLColumnValue(YSQLColumn.createDummy("*"), null)); + } + List fetchColumns = new ArrayList<>(); + List targetColumns = Randomly.nonEmptySubset(targetTables.getColumns()); + for (YSQLColumn c : targetColumns) { + fetchColumns.add(new YSQLColumnValue(c, null)); + } + return fetchColumns; + } + + @Override + protected ExpressionGenerator getGen() { + return gen; + } + +} diff --git a/src/sqlancer/yugabyte/ysql/oracle/tlp/YSQLTLPHavingOracle.java b/src/sqlancer/yugabyte/ysql/oracle/tlp/YSQLTLPHavingOracle.java new file mode 100644 index 000000000..32df70b11 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/oracle/tlp/YSQLTLPHavingOracle.java @@ -0,0 +1,66 @@ +package sqlancer.yugabyte.ysql.oracle.tlp; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import sqlancer.ComparatorHelper; +import sqlancer.Randomly; +import sqlancer.yugabyte.ysql.YSQLErrors; +import sqlancer.yugabyte.ysql.YSQLGlobalState; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLDataType; +import sqlancer.yugabyte.ysql.YSQLVisitor; +import sqlancer.yugabyte.ysql.ast.YSQLExpression; + +public class YSQLTLPHavingOracle extends YSQLTLPBase { + + public YSQLTLPHavingOracle(YSQLGlobalState state) { + super(state); + YSQLErrors.addGroupingErrors(errors); + } + + @Override + public void check() throws SQLException { + super.check(); + havingCheck(); + } + + @Override + List generateFetchColumns() { + List expressions = gen.allowAggregates(true).generateExpressions(Randomly.smallNumber() + 1); + gen.allowAggregates(false); + return expressions; + } + + protected void havingCheck() throws SQLException { + if (Randomly.getBoolean()) { + select.setWhereClause(gen.generateExpression(YSQLDataType.BOOLEAN)); + } + select.setGroupByExpressions(gen.generateExpressions(Randomly.smallNumber() + 1)); + select.setHavingClause(null); + String originalQueryString = YSQLVisitor.asString(select); + List resultSet = ComparatorHelper.getResultSetFirstColumnAsString(originalQueryString, errors, state); + + boolean orderBy = Randomly.getBoolean(); + if (orderBy) { + select.setOrderByClauses(gen.generateOrderBys()); + } + select.setHavingClause(predicate); + String firstQueryString = YSQLVisitor.asString(select); + select.setHavingClause(negatedPredicate); + String secondQueryString = YSQLVisitor.asString(select); + select.setHavingClause(isNullPredicate); + String thirdQueryString = YSQLVisitor.asString(select); + List combinedString = new ArrayList<>(); + List secondResultSet = ComparatorHelper.getCombinedResultSet(firstQueryString, secondQueryString, + thirdQueryString, combinedString, !orderBy, state, errors); + ComparatorHelper.assumeResultSetsAreEqual(resultSet, secondResultSet, originalQueryString, combinedString, + state); + } + + @Override + protected YSQLExpression generatePredicate() { + return gen.generateHavingClause(); + } + +} diff --git a/src/sqlancer/yugabyte/ysql/oracle/tlp/YSQLTLPWhereOracle.java b/src/sqlancer/yugabyte/ysql/oracle/tlp/YSQLTLPWhereOracle.java new file mode 100644 index 000000000..265586586 --- /dev/null +++ b/src/sqlancer/yugabyte/ysql/oracle/tlp/YSQLTLPWhereOracle.java @@ -0,0 +1,45 @@ +package sqlancer.yugabyte.ysql.oracle.tlp; + +import java.sql.SQLException; + +import sqlancer.Reproducer; +import sqlancer.common.oracle.TLPWhereOracle; +import sqlancer.common.oracle.TestOracle; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.yugabyte.ysql.YSQLErrors; +import sqlancer.yugabyte.ysql.YSQLGlobalState; +import sqlancer.yugabyte.ysql.YSQLSchema; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLColumn; +import sqlancer.yugabyte.ysql.YSQLSchema.YSQLTable; +import sqlancer.yugabyte.ysql.ast.YSQLExpression; +import sqlancer.yugabyte.ysql.ast.YSQLJoin; +import sqlancer.yugabyte.ysql.ast.YSQLSelect; +import sqlancer.yugabyte.ysql.gen.YSQLExpressionGenerator; + +public class YSQLTLPWhereOracle implements TestOracle { + + private final TLPWhereOracle oracle; + + public YSQLTLPWhereOracle(YSQLGlobalState state) { + YSQLExpressionGenerator gen = new YSQLExpressionGenerator(state); + ExpectedErrors expectedErrors = ExpectedErrors.newErrors().with(YSQLErrors.getCommonExpressionErrors()) + .with(YSQLErrors.getCommonFetchErrors()).build(); + + this.oracle = new TLPWhereOracle<>(state, gen, expectedErrors); + } + + @Override + public void check() throws SQLException { + oracle.check(); + } + + @Override + public String getLastQueryString() { + return oracle.getLastQueryString(); + } + + @Override + public Reproducer getLastReproducer() { + return oracle.getLastReproducer(); + } +} diff --git a/test/sqlancer/TestCommonSchema.java b/test/sqlancer/TestCommonSchema.java new file mode 100644 index 000000000..f3a2ee2ba --- /dev/null +++ b/test/sqlancer/TestCommonSchema.java @@ -0,0 +1,186 @@ +package sqlancer; + +import org.junit.jupiter.api.Test; +import sqlancer.common.schema.AbstractSchema; +import sqlancer.common.schema.AbstractTable; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.common.schema.AbstractTables; +import sqlancer.common.schema.TableIndex; + +import java.util.*; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.*; + +public class TestCommonSchema { + static class TestTable extends AbstractTable> { + TestTable(String name, List columns, List indexes, boolean isView) { + super(name, columns, indexes, isView); + } + + @Override + public long getNrRows(GlobalState globalState) { + return 0; + } + } + + static class TestTableColumn extends AbstractTableColumn { + TestTableColumn(String name, TestTable table, String type) { + super(name, table, type); + } + } + + static class TestSchema extends AbstractSchema, TestTable> { + TestSchema(List tables) { + super(tables); + } + } + + static class TestTables extends AbstractTables { + TestTables(List tables) { + super(tables); + } + } + + static class TestIndex extends TableIndex { + TestIndex(String name) { + super(name); + } + } + + private TestTable createTestTable(String name, List indexes, boolean isView, String... columns) { + List cols = Arrays.stream(columns).map(col -> new TestTableColumn(col, null, "VARCHAR")) + .collect(Collectors.toList()); + return new TestTable(name, cols, indexes, isView); + } + + private TestTableColumn createTestColumn(String name, TestTable table, String type) { + return new TestTableColumn(name, table, type); + } + + private TestSchema createTestSchema(TestTable... tables) { + return new TestSchema(Arrays.asList(tables)); + } + + private TestTables createTestTables(TestTable... tables) { + return new TestTables(new ArrayList(Arrays.asList(tables))); + } + + @Test + void testColumnManagement() { + TestTable table = createTestTable("products", Collections.emptyList(), false, "sku", "price"); + List columnNames = table.getColumns().stream().map(TestTableColumn::getName) + .collect(Collectors.toList()); + List columnTypes = table.getColumns().stream().map(TestTableColumn::getType) + .collect(Collectors.toList()); + TestTableColumn randomCol = table.getRandomColumn(); + + assertTrue(columnNames.containsAll(Set.of("sku", "price"))); + assertTrue(columnTypes.containsAll(Set.of("VARCHAR"))); + assertTrue(table.getColumns().contains(randomCol)); + } + + @Test + void testIndexManagement() { + TestIndex idx1 = new TestIndex("idx_sku"); + TestIndex idx2 = new TestIndex("idx_price"); + TestTable table = createTestTable("products", Arrays.asList(idx1, idx2), false, "sku", "price"); + TableIndex randomIndex = table.getRandomIndex(); + + assertTrue(table.hasIndexes()); + assertEquals(2, table.getIndexes().size()); + assertTrue(table.getIndexes().contains(randomIndex)); + } + + @Test + void testViewManagement() { + TestTable view1 = createTestTable("v1", Collections.emptyList(), true, "col1"); + TestTable view2 = createTestTable("v2", Collections.emptyList(), true, "col2"); + TestTable table = createTestTable("t1", Collections.emptyList(), false, "col3"); + TestSchema schema = createTestSchema(view1, view2, table); + + assertAll(() -> assertEquals(2, schema.getViews().size(), "Should detect 2 views"), + () -> assertEquals(1, schema.getDatabaseTablesWithoutViews().size(), "Should detect 1 normal table"), + () -> assertEquals("t1", schema.getDatabaseTablesWithoutViews().get(0).getName())); + } + + @Test + void testFreeColumnNameGeneration() { + TestTable table = createTestTable("users", Collections.emptyList(), false, "id", "name"); + Set generatedNames = new HashSet<>(); + + for (int i = 0; i < 100; i++) { + String newName = table.getFreeColumnName(); + assertTrue(generatedNames.add(newName), "Duplicate: " + newName); + + List newColumns = new ArrayList<>(table.getColumns()); + newColumns.add(new TestTableColumn(newName, table, "TEXT")); + table = new TestTable(table.getName(), newColumns, table.getIndexes(), table.isView()); + } + } + + @Test + void testObjectComparison() { + TestTable tableA = createTestTable("A", Collections.emptyList(), false, "x", "y"); + TestTable tableB = createTestTable("B", Collections.emptyList(), false, "b"); + + TestTableColumn colA1 = new TestTableColumn("x", tableA, "INT"); + TestTableColumn colA2 = new TestTableColumn("y", tableA, "INT"); + TestTableColumn colB1 = new TestTableColumn("b", tableB, "TEXT"); + + assertAll(() -> assertTrue(colA1.compareTo(colA2) < 0, "Columns should be ordered by name"), + () -> assertTrue(colA1.compareTo(colB1) < 0, "Columns should be ordered by name"), + () -> assertTrue(tableA.compareTo(tableB) > 0, "Tables should be ordered reverse-alphabetically"), + () -> assertEquals(0, tableA.compareTo(tableA), "Same table should be equal")); + } + + @Test + void testEquality() { + TestTable table1 = createTestTable("t1", Collections.emptyList(), false, "id"); + TestTable table2 = createTestTable("t2", Collections.emptyList(), false, "id"); + + TestTableColumn col1 = new TestTableColumn("id", table1, "INT"); + TestTableColumn col2 = new TestTableColumn("id", table1, "INT"); + TestTableColumn col3 = new TestTableColumn("id", table2, "INT"); + TestTableColumn col4 = new TestTableColumn("name", table1, "TEXT"); + + assertAll(() -> assertEquals(col1, col2, "Same table/column should be equal"), + () -> assertNotEquals(col1, col3, "Different tables should not be equal"), + () -> assertNotEquals(col1, col4, "Different columns should not be equal"), + () -> assertNotEquals(col1, "invalid_object", "Different types should not be equal")); + } + + @Test + void testBoundaryConditions() { + String longName = "a".repeat(256); + TestTableColumn longCol = new TestTableColumn(longName, null, "TEXT"); + assertEquals(longName, longCol.getName()); + + TestTableColumn col2 = createTestColumn("orphan", null, "UNKNOWN"); + assertEquals("orphan", col2.getFullQualifiedName()); + assertNull(col2.getTable()); + } + + @Test + void testTablesManagement() { + TestTable table1 = createTestTable("t1", Collections.emptyList(), false, "col1"); + TestTable table2 = createTestTable("t2", Collections.emptyList(), false, "col2"); + TestTable table3 = createTestTable("t3", Collections.emptyList(), false, "col3"); + + TestTables tables = createTestTables(table1, table2, table3); + assertEquals(3, tables.getSize(), "Should detect 3 tables"); + assertEquals(3, tables.getColumns().size(), "Should detect 3 columns"); + assertTrue(tables.isContained(table3), "Table3 shoule be contained"); + + TestTable table4 = createTestTable("t4", Collections.emptyList(), false, "col4"); + tables.addTable(table4); + assertEquals(4, tables.getSize(), "Should detect 4 tables"); + assertEquals(4, tables.getColumns().size(), "Should detect 4 columns"); + assertTrue(tables.isContained(table4), "Table4 should be contained"); + + tables.removeTable(table4); + assertEquals(3, tables.getSize(), "Should detect 3 tables"); + assertEquals(3, tables.getColumns().size(), "Should detect 3 columns"); + assertTrue(!tables.isContained(table4), "Table4 should not be contained"); + } +} diff --git a/test/sqlancer/TestComparatorHelper.java b/test/sqlancer/TestComparatorHelper.java index 6815b9258..e8b06388b 100644 --- a/test/sqlancer/TestComparatorHelper.java +++ b/test/sqlancer/TestComparatorHelper.java @@ -2,19 +2,37 @@ import static org.junit.jupiter.api.Assertions.assertThrowsExactly; +import java.sql.SQLException; import java.util.Arrays; import java.util.List; import org.junit.jupiter.api.Test; +import sqlancer.h2.H2Options; +import sqlancer.h2.H2Schema; + public class TestComparatorHelper { // TODO: Implement tests for the other ComparatorHelper methods + // TODO: create test state that not depends on specific database + final SQLGlobalState state = new SQLGlobalState() { + + @Override + protected H2Schema readSchema() throws SQLException { + return H2Schema.fromConnection(getConnection(), getDatabaseName()); + } + + @Override + public MainOptions getOptions() { + return new MainOptions(); + } + }; + @Test public void testAssumeResultSetsAreEqualWithEqualSets() { List r1 = Arrays.asList("a", "b", "c"); List r2 = Arrays.asList("a", "b", "c"); - ComparatorHelper.assumeResultSetsAreEqual(r1, r2, "", Arrays.asList(""), null); + ComparatorHelper.assumeResultSetsAreEqual(r1, r2, "", Arrays.asList(""), state); } @@ -26,7 +44,7 @@ public void testAssumeResultSetsAreEqualWithUnequalLengthSets() { // line occurs before AssertionError is thrown, but it's good enough as an indicator that one of the Exceptions // is raised assertThrowsExactly(NullPointerException.class, () -> { - ComparatorHelper.assumeResultSetsAreEqual(r1, r2, "", Arrays.asList(""), null); + ComparatorHelper.assumeResultSetsAreEqual(r1, r2, "", Arrays.asList(""), state); }); } @@ -38,7 +56,7 @@ public void testAssumeResultSetsAreEqualWithUnequalValueSets() { // line occurs before AssertionError is thrown, but it's good enough as an indicator that one of the Exceptions // is raised assertThrowsExactly(NullPointerException.class, () -> { - ComparatorHelper.assumeResultSetsAreEqual(r1, r2, "", Arrays.asList(""), null); + ComparatorHelper.assumeResultSetsAreEqual(r1, r2, "", Arrays.asList(""), state); }); } @@ -46,7 +64,7 @@ public void testAssumeResultSetsAreEqualWithUnequalValueSets() { public void testAssumeResultSetsAreEqualWithCanonicalizationRule() { List r1 = Arrays.asList("a", "b", "c"); List r2 = Arrays.asList("a", "b", "d"); - ComparatorHelper.assumeResultSetsAreEqual(r1, r2, "", Arrays.asList(""), null, (String s) -> { + ComparatorHelper.assumeResultSetsAreEqual(r1, r2, "", Arrays.asList(""), state, (String s) -> { return s.equals("d") ? "c" : s; }); } diff --git a/test/sqlancer/TestExpectedErrors.java b/test/sqlancer/TestExpectedErrors.java index 69a0f53e6..2a3b1a67e 100644 --- a/test/sqlancer/TestExpectedErrors.java +++ b/test/sqlancer/TestExpectedErrors.java @@ -1,5 +1,8 @@ package sqlancer; +import java.util.List; +import java.util.regex.Pattern; + import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -16,7 +19,7 @@ public void testEmpty() { } @Test - public void testSimple() { + public void testStringSimple() { ExpectedErrors errors = new ExpectedErrors(); errors.add("a"); errors.add("b"); @@ -25,16 +28,132 @@ public void testSimple() { assertTrue(errors.errorIsExpected("b")); assertTrue(errors.errorIsExpected("c")); assertTrue(errors.errorIsExpected("aa")); + assertFalse(errors.errorIsExpected("d")); + } + + @Test + public void testStringsSimple() { + ExpectedErrors errors = new ExpectedErrors(); + errors.addAll(List.of("a", "b", "c")); + assertTrue(errors.errorIsExpected("a")); + assertTrue(errors.errorIsExpected("b")); + assertTrue(errors.errorIsExpected("c")); + assertTrue(errors.errorIsExpected("aa")); assertFalse(errors.errorIsExpected("d")); + + } + + @Test + public void testRegexSimple() { + ExpectedErrors errors = new ExpectedErrors(); + errors.addRegex(Pattern.compile("a\\d")); + errors.addRegex(Pattern.compile("b\\D")); + errors.add("c"); + assertTrue(errors.errorIsExpected("a0")); + assertTrue(errors.errorIsExpected("bb")); + assertTrue(errors.errorIsExpected("c")); + assertFalse(errors.errorIsExpected("aa")); + + } + + @Test + public void testRegexesSimple() { + ExpectedErrors errors = new ExpectedErrors(); + errors.addAllRegexes(List.of(Pattern.compile("a\\d"), Pattern.compile("b\\D"))); + errors.add("c"); + assertTrue(errors.errorIsExpected("a0")); + assertTrue(errors.errorIsExpected("bb")); + assertTrue(errors.errorIsExpected("c")); + assertFalse(errors.errorIsExpected("aa")); + } + + @Test + public void testRegexStringSimple() { + ExpectedErrors errors = new ExpectedErrors(); + errors.addRegexString("a\\d"); + errors.addRegexString("b\\D"); + errors.add("c"); + assertTrue(errors.errorIsExpected("a0")); + assertTrue(errors.errorIsExpected("bb")); + assertTrue(errors.errorIsExpected("c")); + assertFalse(errors.errorIsExpected("aa")); + + } + + @Test + public void testRegexStrings() { + ExpectedErrors errors = new ExpectedErrors(); + errors.addAllRegexStrings(List.of("a\\d", "b\\D")); + errors.add("c"); + assertTrue(errors.errorIsExpected("a0")); + assertTrue(errors.errorIsExpected("bb")); + assertTrue(errors.errorIsExpected("c")); + assertFalse(errors.errorIsExpected("aa")); } @Test - public void testRealistic() { + public void testStringRealistic() { ExpectedErrors errors = new ExpectedErrors(); errors.add("violated"); assertTrue(errors.errorIsExpected("UNIQUE constraint was violated!")); assertTrue(errors.errorIsExpected("PRIMARY KEY constraint was violated!")); } + @Test + public void testRegexRealistic() { + ExpectedErrors errors = new ExpectedErrors(); + errors.addRegex(Pattern.compile(".violated.")); + assertTrue(errors.errorIsExpected("UNIQUE constraint was violated!")); + assertTrue(errors.errorIsExpected("PRIMARY KEY constraint was violated!")); + } + + @Test + public void testBuilder() { + ExpectedErrors errors = ExpectedErrors.newErrors().with("a", "b", "c").build(); + + assertTrue(errors.errorIsExpected("a")); + assertTrue(errors.errorIsExpected("b")); + assertTrue(errors.errorIsExpected("c")); + assertTrue(errors.errorIsExpected("aa")); + assertFalse(errors.errorIsExpected("d")); + + errors = ExpectedErrors.newErrors().with(List.of("a", "b", "c")).build(); + + assertTrue(errors.errorIsExpected("a")); + assertTrue(errors.errorIsExpected("b")); + assertTrue(errors.errorIsExpected("c")); + assertTrue(errors.errorIsExpected("aa")); + assertFalse(errors.errorIsExpected("d")); + + errors = ExpectedErrors.newErrors().withRegexString("a\\d", "b\\D").with("c").build(); + + assertTrue(errors.errorIsExpected("a0")); + assertTrue(errors.errorIsExpected("bb")); + assertTrue(errors.errorIsExpected("c")); + assertFalse(errors.errorIsExpected("aa")); + + errors = ExpectedErrors.newErrors().withRegexString(List.of("a\\d", "b\\D")).with("c").build(); + + assertTrue(errors.errorIsExpected("a0")); + assertTrue(errors.errorIsExpected("bb")); + assertTrue(errors.errorIsExpected("c")); + assertFalse(errors.errorIsExpected("aa")); + + errors = ExpectedErrors.newErrors().withRegex(Pattern.compile("a\\d"), Pattern.compile("b\\D")).with("c") + .build(); + + assertTrue(errors.errorIsExpected("a0")); + assertTrue(errors.errorIsExpected("bb")); + assertTrue(errors.errorIsExpected("c")); + assertFalse(errors.errorIsExpected("aa")); + + errors = ExpectedErrors.newErrors().withRegex(List.of(Pattern.compile("a\\d"), Pattern.compile("b\\D"))) + .with("c").build(); + + assertTrue(errors.errorIsExpected("a0")); + assertTrue(errors.errorIsExpected("bb")); + assertTrue(errors.errorIsExpected("c")); + assertFalse(errors.errorIsExpected("aa")); + } } diff --git a/test/sqlancer/TestLoggableFactory.java b/test/sqlancer/TestLoggableFactory.java new file mode 100644 index 000000000..2679cf740 --- /dev/null +++ b/test/sqlancer/TestLoggableFactory.java @@ -0,0 +1,17 @@ +package sqlancer; + +import org.junit.jupiter.api.Test; +import sqlancer.common.log.SQLLoggableFactory; +import sqlancer.common.query.SQLQueryAdapter; + +public class TestLoggableFactory { + + @Test + public void testLogCreateTable() { + String query = "CREATE TABLE t1 (c1 INT)"; + SQLLoggableFactory logger = new SQLLoggableFactory(); + SQLQueryAdapter queryAdapter = logger.getQueryForStateToReproduce(query); + assert (queryAdapter.couldAffectSchema()); + } + +} diff --git a/test/sqlancer/TestRandomly.java b/test/sqlancer/TestRandomly.java index 6be6b1bd4..8fbd97790 100644 --- a/test/sqlancer/TestRandomly.java +++ b/test/sqlancer/TestRandomly.java @@ -38,6 +38,7 @@ public void testSubset() { boolean encounteredStrictSubsetNonEmpty = false; Integer[] options = { 1, 2, 3 }; List optionList = new ArrayList<>(Arrays.asList(options)); + int i = 0; do { List subset = Randomly.subset(optionList); assertEquals(optionList.size(), 3); // check that the original set hasn't been modified @@ -49,7 +50,13 @@ public void testSubset() { } else { encounteredStrictSubsetNonEmpty = true; } - } while (!encounteredEmptySubset || !encounteredOriginalSet || !encounteredStrictSubsetNonEmpty); + } while (!encounteredEmptySubset || !encounteredOriginalSet || !encounteredStrictSubsetNonEmpty + || i++ < NR_MIN_RUNS); + + assertTrue(encounteredEmptySubset, "Empty subset was not encountered"); + assertTrue(encounteredOriginalSet, "Original set was not encountered"); + assertTrue(encounteredStrictSubsetNonEmpty, "Strict subset was not encountered"); + } @Test @@ -75,6 +82,10 @@ public void testString() { } } while (!encounteredInteger || !encounteredAscii || !encounteredNonAscii || !encounteredSpace || i++ < NR_MIN_RUNS); + assertTrue(encounteredInteger, "Integer was not encountered"); + assertTrue(encounteredAscii, "Ascii was not encountered"); + assertTrue(encounteredNonAscii, "Non ascii was not encountered"); + assertTrue(encounteredSpace, "Space was not encountered"); } @Test // TODO: also generate and check for NaN @@ -84,6 +95,7 @@ public void testDouble() { boolean encounteredPositive = false; boolean encounteredNegative = false; boolean encounteredInfinity = false; + int i = 0; do { double doubleVal = r.getDouble(); if (doubleVal == 0) { @@ -97,7 +109,12 @@ public void testDouble() { } else { fail(String.valueOf(doubleVal)); } - } while (!encounteredZero || !encounteredPositive || !encounteredNegative || !encounteredInfinity); + } while (!encounteredZero || !encounteredPositive || !encounteredNegative || !encounteredInfinity + || i++ < NR_MIN_RUNS); + assertTrue(encounteredZero, "Zero was not encountered"); + assertTrue(encounteredPositive, "Positive was not encountered"); + assertTrue(encounteredNegative, "Negative was not encountered"); + assertTrue(encounteredInfinity, "Infinity was not encountered"); } @Test @@ -123,6 +140,8 @@ public void testNonZeroInteger() { encounteredNegative = true; } } while (!encounteredPositive || !encounteredNegative || i++ < NR_MIN_RUNS); + assertTrue(encounteredPositive, "Positive integer was not encountered"); + assertTrue(encounteredNegative, "Negative integer was not encountered"); } @Test @@ -140,6 +159,8 @@ public void testPositiveInteger() { encounteredMaxValue = true; } } while (!encounteredZero || !encounteredMaxValue || i++ < NR_MIN_RUNS); + assertTrue(encounteredZero, "Zero was not encountered"); + assertTrue(encounteredMaxValue, "Max value was not encountered"); } @Test @@ -159,6 +180,9 @@ public void testBytes() { encounteredMax = true; } } while (!encounteredAllZeroes || !encounteredMax || !encounteredZeroLength || i++ < NR_MIN_RUNS); + assertTrue(encounteredAllZeroes, "All zeroes were not encountered"); + assertTrue(encounteredMax, "Max value was not encountered"); + assertTrue(encounteredZeroLength, "Zero length was not encountered"); } @Test @@ -234,4 +258,63 @@ private List getRandomValueList(Randomly r) { return values; } + @Test + public void testGetPercentage() { + for (int i = 0; i < NR_MIN_RUNS; i++) { + double percentage = Randomly.getPercentage(); + assertTrue(percentage >= 0.0); + assertTrue(percentage <= 1.0); + } + } + + @Test + public void testGetChar() { + Randomly r = new Randomly(); + boolean encounteredAlphabetic = false; + boolean encounteredNumeric = false; + boolean encounteredSpecial = false; + int i = 0; + do { + String c = r.getChar(); + assertEquals(1, c.length()); + if (Character.isAlphabetic(c.charAt(0))) { + encounteredAlphabetic = true; + } else if (Character.isDigit(c.charAt(0))) { + encounteredNumeric = true; + } else { + encounteredSpecial = true; + } + } while (!encounteredAlphabetic || !encounteredNumeric || !encounteredSpecial || i++ < NR_MIN_RUNS); + assertTrue(encounteredAlphabetic, "Never encounter an alphabetic character."); + assertTrue(encounteredNumeric, "Never encounter a numeric character."); + assertTrue(encounteredSpecial, "Never encounter a special character."); + } + + @Test + public void testGetAlphabeticChar() { + Randomly r = new Randomly(); + for (int i = 0; i < NR_MIN_RUNS; i++) { + String c = r.getAlphabeticChar(); + assertEquals(1, c.length()); + assertTrue(Character.isAlphabetic(c.charAt(0))); + } + } + + @Test + public void testGetBooleanWithSmallProbability() { + int trueCount = 0; + int totalRuns = NR_MIN_RUNS; + + for (int i = 0; i < totalRuns; i++) { + if (Randomly.getBooleanWithSmallProbability()) { + trueCount++; + } + } + + double trueRatio = (double) trueCount / totalRuns; + assertTrue(trueRatio > 0.005); + assertTrue(trueRatio < 0.015); + + } + } diff --git a/test/sqlancer/TestStateToReproduce.java b/test/sqlancer/TestStateToReproduce.java new file mode 100644 index 000000000..6fcdc22f1 --- /dev/null +++ b/test/sqlancer/TestStateToReproduce.java @@ -0,0 +1,109 @@ +package sqlancer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.Query; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.sqlite3.SQLite3Provider; + +public class TestStateToReproduce { + + @TempDir + Path tempDir; + + @Test + public void testBasicFields() throws IOException { + SQLite3Provider provider = new SQLite3Provider(); + StateToReproduce state = new StateToReproduce("test_db", provider); + state.databaseVersion = "3.36.0"; + state.seedValue = 12345L; + state.exception = "Test exception message"; + + Path file = tempDir.resolve("test_basic.ser"); + state.serialize(file); + StateToReproduce result = StateToReproduce.deserialize(file); + + assertEquals(state.getDatabaseName(), result.getDatabaseName()); + assertEquals(state.getDatabaseVersion(), result.getDatabaseVersion()); + assertEquals(state.getSeedValue(), result.getSeedValue()); + assertEquals(state.getException(), result.getException()); + } + + @Test + public void testStatements() throws IOException { + SQLite3Provider provider = new SQLite3Provider(); + StateToReproduce state = new StateToReproduce("test_statements", provider); + List> statements = new ArrayList<>(); + + ExpectedErrors errors1 = new ExpectedErrors(); + errors1.add("syntax error"); + errors1.add("table already exists"); + statements.add(new SQLQueryAdapter("CREATE TABLE test (id INTEGER);", errors1)); + + ExpectedErrors errors2 = new ExpectedErrors(); + errors2.add("constraint failed"); + statements.add(new SQLQueryAdapter("INSERT INTO test VALUES (1);", errors2)); + + statements.add(new SQLQueryAdapter("SELECT * FROM test;", new ExpectedErrors())); + state.setStatements(statements); + + Path file = tempDir.resolve("test_statements.ser"); + state.serialize(file); + StateToReproduce result = StateToReproduce.deserialize(file); + + List> resultStatements = result.getStatements(); + assertEquals(3, resultStatements.size()); + + Query q1 = resultStatements.get(0); + Query q2 = resultStatements.get(1); + Query q3 = resultStatements.get(2); + + assertEquals("CREATE TABLE test (id INTEGER);", q1.getLogString()); + assertEquals("INSERT INTO test VALUES (1);", q2.getLogString()); + assertEquals("SELECT * FROM test;", q3.getLogString()); + + ExpectedErrors e1 = q1.getExpectedErrors(); + ExpectedErrors e2 = q2.getExpectedErrors(); + ExpectedErrors e3 = q3.getExpectedErrors(); + + assertTrue(e1.errorIsExpected("syntax error")); + assertTrue(e1.errorIsExpected("table already exists")); + assertFalse(e1.errorIsExpected("constraint failed")); + + assertFalse(e2.errorIsExpected("syntax error")); + assertTrue(e2.errorIsExpected("constraint failed")); + + assertFalse(e3.errorIsExpected("syntax error")); + assertFalse(e3.errorIsExpected("constraint failed")); + } + + @Test + public void testDatabaseProvider() throws IOException { + SQLite3Provider provider = new SQLite3Provider(); + StateToReproduce state = new StateToReproduce("test_provider", provider); + state.logStatement("CREATE TABLE test (id INTEGER);"); + + Path file = tempDir.resolve("test_provider.ser"); + state.serialize(file); + StateToReproduce result = StateToReproduce.deserialize(file); + + // Verify databaseProvider is correctly deserialized + assertEquals("sqlite3", result.getDatabaseProvider().getDBMSName()); + + // Verify databaseProvider functionality by testing logStatement + result.logStatement("INSERT INTO test VALUES (1);"); + assertEquals(2, result.getStatements().size()); + assertEquals("INSERT INTO test VALUES (1);", result.getStatements().get(1).getLogString()); + } +} \ No newline at end of file diff --git a/test/sqlancer/clickhouse/ast/ClickHouseBinaryComparisonOperationTest.java b/test/sqlancer/clickhouse/ast/ClickHouseBinaryComparisonOperationTest.java index d9dd89a5f..97ca1e282 100644 --- a/test/sqlancer/clickhouse/ast/ClickHouseBinaryComparisonOperationTest.java +++ b/test/sqlancer/clickhouse/ast/ClickHouseBinaryComparisonOperationTest.java @@ -5,7 +5,8 @@ import java.util.Arrays; import java.util.stream.Collectors; -import ru.yandex.clickhouse.domain.ClickHouseDataType; +import com.clickhouse.client.ClickHouseDataType; +import sqlancer.clickhouse.ast.constant.ClickHouseCreateConstant; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -13,117 +14,117 @@ class ClickHouseBinaryComparisonOperationTest { @Test void getExpectedValueTrueEqualsTrue() { - ClickHouseConstant trueConst = ClickHouseConstant.createTrue(); - ClickHouseConstant equals = trueConst.applyEquals(ClickHouseConstant.createTrue()); - assertEquals(equals.asInt(), 1); + ClickHouseConstant trueConst = ClickHouseCreateConstant.createTrue(); + ClickHouseConstant equals = trueConst.applyEquals(ClickHouseCreateConstant.createTrue()); + assertEquals(true, equals.asBooleanNotNull()); } @Test void getExpectedValueTrueNotEqualsFalse() { - ClickHouseConstant trueConst = ClickHouseConstant.createTrue(); - ClickHouseConstant falseConst = ClickHouseConstant.createFalse(); - ClickHouseConstant equals = trueConst.applyEquals(ClickHouseConstant.createFalse()); - ClickHouseConstant equalsFalse = falseConst.applyEquals(ClickHouseConstant.createTrue()); - assertEquals(equals.asInt(), 0); - assertEquals(equalsFalse.asInt(), 0); + ClickHouseConstant trueConst = ClickHouseCreateConstant.createTrue(); + ClickHouseConstant falseConst = ClickHouseCreateConstant.createFalse(); + ClickHouseConstant equals = trueConst.applyEquals(ClickHouseCreateConstant.createFalse()); + ClickHouseConstant equalsFalse = falseConst.applyEquals(ClickHouseCreateConstant.createTrue()); + assertEquals(false, equals.asBooleanNotNull()); + assertEquals(false, equalsFalse.asBooleanNotNull()); } @Test void getExpectedValueFloat64EqualsFloat64() { - ClickHouseConstant oneConst = ClickHouseConstant.createFloat64Constant(1); - ClickHouseConstant oneFConst = ClickHouseConstant.createFloat64Constant(1.0); - ClickHouseConstant zeroConst = ClickHouseConstant.createFloat64Constant(0); - ClickHouseConstant zeroFConst = ClickHouseConstant.createFloat64Constant(0.0); - - assertEquals(oneConst.applyEquals(oneConst).asInt(), 1); - assertEquals(oneFConst.applyEquals(oneFConst).asInt(), 1); - assertEquals(oneConst.applyEquals(oneFConst).asInt(), 1); - - assertEquals(oneConst.applyEquals(zeroConst).asInt(), 0); - assertEquals(oneFConst.applyEquals(zeroFConst).asInt(), 0); - assertEquals(zeroConst.applyEquals(zeroFConst).asInt(), 1); - assertEquals(zeroFConst.applyEquals(zeroConst).asInt(), 1); + ClickHouseConstant oneConst = ClickHouseCreateConstant.createFloat64Constant(1); + ClickHouseConstant oneFConst = ClickHouseCreateConstant.createFloat64Constant(1.0); + ClickHouseConstant zeroConst = ClickHouseCreateConstant.createFloat64Constant(0); + ClickHouseConstant zeroFConst = ClickHouseCreateConstant.createFloat64Constant(0.0); + + assertEquals(true, oneConst.applyEquals(oneConst).asBooleanNotNull()); + assertEquals(true, oneFConst.applyEquals(oneFConst).asBooleanNotNull()); + assertEquals(true, oneConst.applyEquals(oneFConst).asBooleanNotNull()); + + assertEquals(false, oneConst.applyEquals(zeroConst).asBooleanNotNull()); + assertEquals(false, oneFConst.applyEquals(zeroFConst).asBooleanNotNull()); + assertEquals(true, zeroConst.applyEquals(zeroFConst).asBooleanNotNull()); + assertEquals(true, zeroFConst.applyEquals(zeroConst).asBooleanNotNull()); } @Test void getExpectedValueInt32EqualsBool() { - ClickHouseConstant trueConst = ClickHouseConstant.createTrue(); - ClickHouseConstant falseConst = ClickHouseConstant.createFalse(); - ClickHouseConstant oneConst = ClickHouseConstant.createInt32Constant(1); - ClickHouseConstant zeroConst = ClickHouseConstant.createInt32Constant(0); - ClickHouseConstant negativeConst = ClickHouseConstant.createInt32Constant(-100); - ClickHouseConstant positiveConst = ClickHouseConstant.createInt32Constant(10000); - - assertEquals(trueConst.applyEquals(oneConst).asInt(), 1); - assertEquals(oneConst.applyEquals(oneConst).asInt(), 1); - assertEquals(falseConst.applyEquals(oneConst).asInt(), 0); - - assertEquals(trueConst.applyEquals(zeroConst).asInt(), 0); - assertEquals(oneConst.applyEquals(zeroConst).asInt(), 0); - assertEquals(falseConst.applyEquals(zeroConst).asInt(), 1); - - assertEquals(negativeConst.applyEquals(oneConst).asInt(), 0); - assertEquals(negativeConst.applyEquals(zeroConst).asInt(), 0); - assertEquals(negativeConst.applyEquals(trueConst).asInt(), 0); - assertEquals(negativeConst.applyEquals(falseConst).asInt(), 0); - - assertEquals(positiveConst.applyEquals(oneConst).asInt(), 0); - assertEquals(positiveConst.applyEquals(zeroConst).asInt(), 0); - assertEquals(positiveConst.applyEquals(trueConst).asInt(), 0); - assertEquals(positiveConst.applyEquals(falseConst).asInt(), 0); + ClickHouseConstant trueConst = ClickHouseCreateConstant.createTrue(); + ClickHouseConstant falseConst = ClickHouseCreateConstant.createFalse(); + ClickHouseConstant oneConst = ClickHouseCreateConstant.createInt32Constant(1); + ClickHouseConstant zeroConst = ClickHouseCreateConstant.createInt32Constant(0); + ClickHouseConstant negativeConst = ClickHouseCreateConstant.createInt32Constant(-100); + ClickHouseConstant positiveConst = ClickHouseCreateConstant.createInt32Constant(10000); + + assertEquals(true, trueConst.applyEquals(oneConst).asBooleanNotNull()); + assertEquals(true, oneConst.applyEquals(oneConst).asBooleanNotNull()); + assertEquals(false, falseConst.applyEquals(oneConst).asBooleanNotNull()); + + assertEquals(false, trueConst.applyEquals(zeroConst).asBooleanNotNull()); + assertEquals(false, oneConst.applyEquals(zeroConst).asBooleanNotNull()); + assertEquals(true, falseConst.applyEquals(zeroConst).asBooleanNotNull()); + + assertEquals(false, negativeConst.applyEquals(oneConst).asBooleanNotNull()); + assertEquals(false, negativeConst.applyEquals(zeroConst).asBooleanNotNull()); + assertEquals(false, negativeConst.applyEquals(trueConst).asBooleanNotNull()); + assertEquals(false, negativeConst.applyEquals(falseConst).asBooleanNotNull()); + + assertEquals(false, positiveConst.applyEquals(oneConst).asBooleanNotNull()); + assertEquals(false, positiveConst.applyEquals(zeroConst).asBooleanNotNull()); + assertEquals(false, positiveConst.applyEquals(trueConst).asBooleanNotNull()); + assertEquals(false, positiveConst.applyEquals(falseConst).asBooleanNotNull()); } @Test void getExpectedValueIntEqualsInt() { - ClickHouseConstant trueConst = ClickHouseConstant.createTrue(); - ClickHouseConstant falseConst = ClickHouseConstant.createFalse(); + ClickHouseConstant trueConst = ClickHouseCreateConstant.createTrue(); + ClickHouseConstant falseConst = ClickHouseCreateConstant.createFalse(); for (ClickHouseDataType type : Arrays. stream(ClickHouseDataType.values()) .filter((dt) -> dt.name().contains("Int") && !dt.name().contains("Interval")) .collect(Collectors.toList())) { - ClickHouseConstant oneConst = ClickHouseConstant.createIntConstant(type, 1); - ClickHouseConstant zeroConst = ClickHouseConstant.createIntConstant(type, 0); - ClickHouseConstant negativeConst = ClickHouseConstant.createIntConstant(type, -100); - ClickHouseConstant positiveConst = ClickHouseConstant.createIntConstant(type, 10000); - - assertEquals(trueConst.applyEquals(oneConst).asInt(), 1); - assertEquals(oneConst.applyEquals(oneConst).asInt(), 1); - assertEquals(falseConst.applyEquals(oneConst).asInt(), 0); - - assertEquals(trueConst.applyEquals(zeroConst).asInt(), 0); - assertEquals(oneConst.applyEquals(zeroConst).asInt(), 0); - assertEquals(falseConst.applyEquals(zeroConst).asInt(), 1); - - assertEquals(negativeConst.applyEquals(oneConst).asInt(), 0); - assertEquals(negativeConst.applyEquals(zeroConst).asInt(), 0); - assertEquals(negativeConst.applyEquals(trueConst).asInt(), 0); - assertEquals(negativeConst.applyEquals(falseConst).asInt(), 0); - - assertEquals(positiveConst.applyEquals(oneConst).asInt(), 0); - assertEquals(positiveConst.applyEquals(zeroConst).asInt(), 0); - assertEquals(positiveConst.applyEquals(trueConst).asInt(), 0); - assertEquals(positiveConst.applyEquals(falseConst).asInt(), 0); + ClickHouseConstant oneConst = ClickHouseCreateConstant.createIntConstant(type, 1); + ClickHouseConstant zeroConst = ClickHouseCreateConstant.createIntConstant(type, 0); + ClickHouseConstant negativeConst = ClickHouseCreateConstant.createIntConstant(type, -100); + ClickHouseConstant positiveConst = ClickHouseCreateConstant.createIntConstant(type, 10000); + + assertEquals(true, trueConst.applyEquals(oneConst).asBooleanNotNull()); + assertEquals(true, oneConst.applyEquals(oneConst).asBooleanNotNull()); + assertEquals(false, falseConst.applyEquals(oneConst).asBooleanNotNull()); + + assertEquals(false, trueConst.applyEquals(zeroConst).asBooleanNotNull()); + assertEquals(false, oneConst.applyEquals(zeroConst).asBooleanNotNull()); + assertEquals(true, falseConst.applyEquals(zeroConst).asBooleanNotNull()); + + assertEquals(false, negativeConst.applyEquals(oneConst).asBooleanNotNull()); + assertEquals(false, negativeConst.applyEquals(zeroConst).asBooleanNotNull()); + assertEquals(false, negativeConst.applyEquals(trueConst).asBooleanNotNull()); + assertEquals(false, negativeConst.applyEquals(falseConst).asBooleanNotNull()); + + assertEquals(false, positiveConst.applyEquals(oneConst).asBooleanNotNull()); + assertEquals(false, positiveConst.applyEquals(zeroConst).asBooleanNotNull()); + assertEquals(false, positiveConst.applyEquals(trueConst).asBooleanNotNull()); + assertEquals(false, positiveConst.applyEquals(falseConst).asBooleanNotNull()); } } @Test void getExpectedValueInt32EqualsFloat64() { - ClickHouseConstant float64OneConst = ClickHouseConstant.createFloat64Constant(1.0); - ClickHouseConstant float64ZeroConst = ClickHouseConstant.createFloat64Constant(0.0); - ClickHouseConstant oneConst = ClickHouseConstant.createInt32Constant(1); - ClickHouseConstant zeroConst = ClickHouseConstant.createInt32Constant(0); - ClickHouseConstant negativeConst = ClickHouseConstant.createInt32Constant(-100); - ClickHouseConstant positiveConst = ClickHouseConstant.createInt32Constant(10000); + ClickHouseConstant float64OneConst = ClickHouseCreateConstant.createFloat64Constant(1.0); + ClickHouseConstant float64ZeroConst = ClickHouseCreateConstant.createFloat64Constant(0.0); + ClickHouseConstant oneConst = ClickHouseCreateConstant.createInt32Constant(1); + ClickHouseConstant zeroConst = ClickHouseCreateConstant.createInt32Constant(0); + ClickHouseConstant negativeConst = ClickHouseCreateConstant.createInt32Constant(-100); + ClickHouseConstant positiveConst = ClickHouseCreateConstant.createInt32Constant(10000); - assertEquals(float64OneConst.applyEquals(oneConst).asInt(), 1); - assertEquals(float64ZeroConst.applyEquals(oneConst).asInt(), 0); + assertEquals(true, float64OneConst.applyEquals(oneConst).asBooleanNotNull()); + assertEquals(false, float64ZeroConst.applyEquals(oneConst).asBooleanNotNull()); - assertEquals(float64OneConst.applyEquals(zeroConst).asInt(), 0); - assertEquals(float64ZeroConst.applyEquals(zeroConst).asInt(), 1); + assertEquals(false, float64OneConst.applyEquals(zeroConst).asBooleanNotNull()); + assertEquals(true, float64ZeroConst.applyEquals(zeroConst).asBooleanNotNull()); - assertEquals(negativeConst.applyEquals(float64OneConst).asInt(), 0); - assertEquals(negativeConst.applyEquals(float64ZeroConst).asInt(), 0); + assertEquals(false, negativeConst.applyEquals(float64OneConst).asBooleanNotNull()); + assertEquals(false, negativeConst.applyEquals(float64ZeroConst).asBooleanNotNull()); - assertEquals(positiveConst.applyEquals(float64OneConst).asInt(), 0); - assertEquals(positiveConst.applyEquals(float64ZeroConst).asInt(), 0); + assertEquals(false, positiveConst.applyEquals(float64OneConst).asBooleanNotNull()); + assertEquals(false, positiveConst.applyEquals(float64ZeroConst).asBooleanNotNull()); } } diff --git a/test/sqlancer/clickhouse/ast/ClickHouseOperatorsVisitorTest.java b/test/sqlancer/clickhouse/ast/ClickHouseOperatorsVisitorTest.java new file mode 100644 index 000000000..fc786a625 --- /dev/null +++ b/test/sqlancer/clickhouse/ast/ClickHouseOperatorsVisitorTest.java @@ -0,0 +1,117 @@ +package sqlancer.clickhouse.ast; + +import org.junit.jupiter.api.Test; +import sqlancer.clickhouse.ClickHouseSchema; +import sqlancer.clickhouse.ClickHouseVisitor; +import sqlancer.clickhouse.ast.constant.ClickHouseCreateConstant; + +import java.util.Arrays; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class ClickHouseOperatorsVisitorTest { + + @Test + void selectUnaryNot() { + ClickHouseConstant trueConst = ClickHouseCreateConstant.createBoolean(true); + ClickHouseExpression notTrue = new ClickHouseUnaryPrefixOperation(trueConst, + ClickHouseUnaryPrefixOperation.ClickHouseUnaryPrefixOperator.NOT); + ClickHouseSelect select = new ClickHouseSelect(); + select.setFetchColumns(Arrays.asList(notTrue)); + String result = ClickHouseVisitor.asString(select); + String answer = "SELECT (NOT (true))"; + assertEquals(answer, result); + } + + @Test + void selectUnaryMinus() { + ClickHouseConstant fiveConst = ClickHouseCreateConstant.createUInt32Constant(5); + ClickHouseExpression minusFive = new ClickHouseUnaryPrefixOperation(fiveConst, + ClickHouseUnaryPrefixOperation.ClickHouseUnaryPrefixOperator.MINUS); + ClickHouseSelect select = new ClickHouseSelect(); + select.setFetchColumns(Arrays.asList(minusFive)); + String result = ClickHouseVisitor.asString(select); + String answer = "SELECT (- (5))"; + assertEquals(answer, result); + } + + @Test + void selectUnaryExp() { + ClickHouseConstant tenConst = ClickHouseCreateConstant.createInt32Constant(10); + ClickHouseExpression minusFive = new ClickHouseUnaryFunctionOperation(tenConst, + ClickHouseUnaryFunctionOperation.ClickHouseUnaryFunctionOperator.EXP); + ClickHouseSelect select = new ClickHouseSelect(); + select.setFetchColumns(Arrays.asList(minusFive)); + String result = ClickHouseVisitor.asString(select); + String answer = "SELECT (exp (10))"; + assertEquals(answer, result); + } + + @Test + void selectBinaryPlus() { + ClickHouseConstant dConst = ClickHouseCreateConstant.createFloat32Constant((float) -1.1); + ClickHouseConstant tenConst = ClickHouseCreateConstant.createInt32Constant(10); + ClickHouseExpression expr = new ClickHouseBinaryArithmeticOperation(dConst, tenConst, + ClickHouseBinaryArithmeticOperation.ClickHouseBinaryArithmeticOperator.ADD); + ClickHouseSelect select = new ClickHouseSelect(); + select.setFetchColumns(Arrays.asList(expr)); + String result = ClickHouseVisitor.asString(select); + String answer = "SELECT ((-1.1)+(10))"; + assertEquals(answer, result); + } + + @Test + void selectBinaryPow() { + ClickHouseConstant threeConst = ClickHouseCreateConstant.createInt8Constant(3); + ClickHouseConstant tenConst = ClickHouseCreateConstant.createInt32Constant(10); + ClickHouseExpression expr = new ClickHouseBinaryFunctionOperation(threeConst, tenConst, + ClickHouseBinaryFunctionOperation.ClickHouseBinaryFunctionOperator.POW); + ClickHouseSelect select = new ClickHouseSelect(); + select.setFetchColumns(Arrays.asList(expr)); + String result = ClickHouseVisitor.asString(select); + String answer = "SELECT pow(3,10)"; + assertEquals(answer, result); + } + + @Test + void selectBinaryLCM() { + ClickHouseConstant aConst = ClickHouseCreateConstant.createInt8Constant(100); + ClickHouseConstant bConst = ClickHouseCreateConstant.createInt32Constant(-100); + ClickHouseExpression expr = new ClickHouseBinaryFunctionOperation(aConst, bConst, + ClickHouseBinaryFunctionOperation.ClickHouseBinaryFunctionOperator.LCM); + ClickHouseSelect select = new ClickHouseSelect(); + select.setFetchColumns(Arrays.asList(expr)); + String result = ClickHouseVisitor.asString(select); + String answer = "SELECT lcm(100,-100)"; + assertEquals(answer, result); + } + + @Test + void selectBinaryDivCol() { + ClickHouseColumnReference a = new ClickHouseColumnReference(new ClickHouseSchema.ClickHouseColumn("a", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, null), null, null); + ClickHouseColumnReference b = new ClickHouseColumnReference(new ClickHouseSchema.ClickHouseColumn("b", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, null), null, null); + ClickHouseExpression expr = new ClickHouseBinaryFunctionOperation(a, b, + ClickHouseBinaryFunctionOperation.ClickHouseBinaryFunctionOperator.INT_DIV); + ClickHouseSelect select = new ClickHouseSelect(); + select.setFetchColumns(Arrays.asList(expr)); + String result = ClickHouseVisitor.asString(select); + String answer = "SELECT intDiv(a,b)"; + assertEquals(answer, result); + } + + @Test + void selectBinaryComp() { + ClickHouseConstant aConst = ClickHouseCreateConstant.createInt8Constant(10); + ClickHouseConstant bConst = ClickHouseCreateConstant.createInt32Constant(100); + ClickHouseExpression expr = new ClickHouseBinaryComparisonOperation(aConst, bConst, + ClickHouseBinaryComparisonOperation.ClickHouseBinaryComparisonOperator.GREATER); + ClickHouseSelect select = new ClickHouseSelect(); + select.setFetchColumns(Arrays.asList(expr)); + String result = ClickHouseVisitor.asString(select); + String answer = "SELECT ((10)>(100))"; + assertEquals(answer, result); + } + +} diff --git a/test/sqlancer/clickhouse/ast/ClickHouseToStringVisitorTest.java b/test/sqlancer/clickhouse/ast/ClickHouseToStringVisitorTest.java new file mode 100644 index 000000000..ba9628037 --- /dev/null +++ b/test/sqlancer/clickhouse/ast/ClickHouseToStringVisitorTest.java @@ -0,0 +1,344 @@ +package sqlancer.clickhouse.ast; + +import org.junit.jupiter.api.Test; +import sqlancer.clickhouse.ClickHouseSchema; +import sqlancer.clickhouse.ClickHouseVisitor; +import sqlancer.clickhouse.ast.constant.ClickHouseInt8Constant; +import sqlancer.common.schema.TableIndex; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class ClickHouseToStringVisitorTest { + + @Test + void select1Test() { + ClickHouseConstant oneConst = new ClickHouseInt8Constant(1); + ClickHouseSelect selectOne = new ClickHouseSelect(); + selectOne.setFetchColumns(Arrays.asList(oneConst)); + String result = ClickHouseVisitor.asString(selectOne); + String answer = "SELECT 1"; + assertEquals(answer, result); + } + + @Test + void select1asATest() { + ClickHouseAliasOperation oneConstAsA = new ClickHouseAliasOperation(new ClickHouseInt8Constant(1), "a"); + ClickHouseSelect selectOne = new ClickHouseSelect(); + selectOne.setFetchColumns(Arrays.asList(oneConstAsA)); + String result = ClickHouseVisitor.asString(selectOne); + String answer = "SELECT 1 AS `a`"; + assertEquals(answer, result); + } + + @Test + void selectATest() { + List empty_col_list = Collections.emptyList(); + List indexes = Collections.emptyList(); + ClickHouseSchema.ClickHouseTable table = new ClickHouseSchema.ClickHouseTable("t", empty_col_list, indexes, + false); + ClickHouseTableReference table_ref = new ClickHouseTableReference(table, null); + ClickHouseSchema.ClickHouseColumn a_col = new ClickHouseSchema.ClickHouseColumn("a", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, table); + a_col.setTable(table); + ClickHouseColumnReference a_ref = a_col.asColumnReference(null); + ClickHouseSelect selectA = new ClickHouseSelect(); + selectA.setFetchColumns(Arrays.asList(a_ref)); + selectA.setFromClause(table_ref); + String result = ClickHouseVisitor.asString(selectA); + String answer = "SELECT t.a FROM t"; + assertEquals(answer, result); + } + + @Test + void selectAasBTest() { + List empty_col_list = Collections.emptyList(); + List indexes = Collections.emptyList(); + ClickHouseSchema.ClickHouseTable table = new ClickHouseSchema.ClickHouseTable("t", empty_col_list, indexes, + false); + ClickHouseTableReference table_ref = new ClickHouseTableReference(table, null); + ClickHouseSchema.ClickHouseColumn a_col = new ClickHouseSchema.ClickHouseColumn("a", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, table); + a_col.setTable(table); + ClickHouseColumnReference a_ref = a_col.asColumnReference(null); + ClickHouseAliasOperation b = new ClickHouseAliasOperation(a_ref, "b"); + ClickHouseColumnReference b_ref = new ClickHouseColumnReference(b); + ClickHouseSelect selectA = new ClickHouseSelect(); + selectA.setFetchColumns(Arrays.asList(b, b_ref)); + selectA.setFromClause(table_ref); + String result = ClickHouseVisitor.asString(selectA); + String answer = "SELECT t.a AS `b`, b FROM t"; + assertEquals(answer, result); + } + + @Test + void selectABTest() { + List empty_col_list = Collections.emptyList(); + List indexes = Collections.emptyList(); + ClickHouseSchema.ClickHouseTable table = new ClickHouseSchema.ClickHouseTable("t", empty_col_list, indexes, + false); + ClickHouseTableReference table_ref = new ClickHouseTableReference(table, null); + ClickHouseSchema.ClickHouseColumn a_col = new ClickHouseSchema.ClickHouseColumn("a", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, table); + ClickHouseSchema.ClickHouseColumn b_col = new ClickHouseSchema.ClickHouseColumn("b", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, table); + a_col.setTable(table); + b_col.setTable(table); + ClickHouseColumnReference a_ref = a_col.asColumnReference(null); + ClickHouseColumnReference b_ref = b_col.asColumnReference(null); + ClickHouseSelect selectAB = new ClickHouseSelect(); + selectAB.setFetchColumns(Arrays.asList(a_ref, b_ref)); + selectAB.setFromClause(table_ref); + String result = ClickHouseVisitor.asString(selectAB); + String answer = "SELECT t.a, t.b FROM t"; + assertEquals(answer, result); + } + + @Test + void selectWhereAGreaterBTest() { + List empty_col_list = Collections.emptyList(); + List indexes = Collections.emptyList(); + ClickHouseSchema.ClickHouseTable table = new ClickHouseSchema.ClickHouseTable("t", empty_col_list, indexes, + false); + ClickHouseTableReference table_ref = new ClickHouseTableReference(table, null); + ClickHouseSchema.ClickHouseColumn a_col = new ClickHouseSchema.ClickHouseColumn("a", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, table); + ClickHouseSchema.ClickHouseColumn b_col = new ClickHouseSchema.ClickHouseColumn("b", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, table); + a_col.setTable(table); + b_col.setTable(table); + ClickHouseColumnReference a_ref = a_col.asColumnReference(null); + ClickHouseColumnReference b_ref = b_col.asColumnReference(null); + ClickHouseSelect selectAB = new ClickHouseSelect(); + selectAB.setFetchColumns(Arrays.asList(a_ref, b_ref)); + selectAB.setFromClause(table_ref); + selectAB.setWhereClause(new ClickHouseBinaryComparisonOperation(a_ref, b_ref, + ClickHouseBinaryComparisonOperation.ClickHouseBinaryComparisonOperator.GREATER)); + String result = ClickHouseVisitor.asString(selectAB); + String answer = "SELECT t.a, t.b FROM t WHERE ((t.a)>(t.b))"; + assertEquals(answer, result); + } + + @Test + void selectWhereAGreaterConstTest() { + List empty_col_list = Collections.emptyList(); + List indexes = Collections.emptyList(); + ClickHouseSchema.ClickHouseTable table = new ClickHouseSchema.ClickHouseTable("t", empty_col_list, indexes, + false); + ClickHouseTableReference table_ref = new ClickHouseTableReference(table, null); + ClickHouseSchema.ClickHouseColumn a_col = new ClickHouseSchema.ClickHouseColumn("a", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, table); + ClickHouseSchema.ClickHouseColumn b_col = new ClickHouseSchema.ClickHouseColumn("b", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, table); + a_col.setTable(table); + b_col.setTable(table); + ClickHouseColumnReference a_ref = a_col.asColumnReference(null); + ClickHouseColumnReference b_ref = b_col.asColumnReference(null); + ClickHouseSelect selectAB = new ClickHouseSelect(); + selectAB.setFetchColumns(Arrays.asList(a_ref, b_ref)); + selectAB.setFromClause(table_ref); + ClickHouseConstant c_const = new ClickHouseInt8Constant(1); + selectAB.setWhereClause(new ClickHouseBinaryComparisonOperation(a_ref, c_const, + ClickHouseBinaryComparisonOperation.ClickHouseBinaryComparisonOperator.GREATER)); + String result = ClickHouseVisitor.asString(selectAB); + String answer = "SELECT t.a, t.b FROM t WHERE ((t.a)>(1))"; + assertEquals(answer, result); + } + + @Test + void selectSumAGroupByBTest() { + List empty_col_list = Collections.emptyList(); + List indexes = Collections.emptyList(); + ClickHouseSchema.ClickHouseTable table = new ClickHouseSchema.ClickHouseTable("t", empty_col_list, indexes, + false); + ClickHouseTableReference table_ref = new ClickHouseTableReference(table, null); + ClickHouseSchema.ClickHouseColumn a_col = new ClickHouseSchema.ClickHouseColumn("a", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, table); + ClickHouseSchema.ClickHouseColumn b_col = new ClickHouseSchema.ClickHouseColumn("b", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, table); + a_col.setTable(table); + b_col.setTable(table); + ClickHouseColumnReference a_ref = a_col.asColumnReference(null); + ClickHouseColumnReference b_ref = b_col.asColumnReference(null); + ClickHouseSelect selectAB = new ClickHouseSelect(); + ClickHouseAggregate sum_a = new ClickHouseAggregate(a_ref, ClickHouseAggregate.ClickHouseAggregateFunction.SUM); + selectAB.setFetchColumns(Arrays.asList(sum_a)); + selectAB.setFromClause(table_ref); + selectAB.setGroupByClause(Arrays.asList(b_ref)); + String result = ClickHouseVisitor.asString(selectAB); + String answer = "SELECT SUM(t.a) FROM t GROUP BY t.b"; + assertEquals(answer, result); + } + + @Test + void selectCrossJoinTest() { + List empty_col_list = Collections.emptyList(); + List indexes = Collections.emptyList(); + ClickHouseSchema.ClickHouseTable table1 = new ClickHouseSchema.ClickHouseTable("t1", empty_col_list, indexes, + false); + ClickHouseSchema.ClickHouseTable table2 = new ClickHouseSchema.ClickHouseTable("t2", empty_col_list, indexes, + false); + ClickHouseTableReference table1_ref = new ClickHouseTableReference(table1, null); + ClickHouseTableReference table2_ref = new ClickHouseTableReference(table2, null); + ClickHouseSchema.ClickHouseColumn a1_col = new ClickHouseSchema.ClickHouseColumn("a1", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, table1); + ClickHouseSchema.ClickHouseColumn b1_col = new ClickHouseSchema.ClickHouseColumn("b1", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, table1); + ClickHouseSchema.ClickHouseColumn a2_col = new ClickHouseSchema.ClickHouseColumn("a2", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, table2); + ClickHouseSchema.ClickHouseColumn b2_col = new ClickHouseSchema.ClickHouseColumn("b2", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, table2); + a1_col.setTable(table1); + b1_col.setTable(table1); + a2_col.setTable(table2); + b2_col.setTable(table2); + + ClickHouseColumnReference a1_ref = a1_col.asColumnReference(null); + ClickHouseColumnReference b1_ref = b1_col.asColumnReference(null); + ClickHouseColumnReference a2_ref = a2_col.asColumnReference(null); + ClickHouseColumnReference b2_ref = b2_col.asColumnReference(null); + + ClickHouseSelect select = new ClickHouseSelect(); + select.setFetchColumns(Arrays.asList(a1_ref, a2_ref, b1_ref, b2_ref)); + select.setFromClause(table1_ref); + ClickHouseExpression.ClickHouseJoin join = new ClickHouseExpression.ClickHouseJoin(table1_ref, table2_ref, + ClickHouseExpression.ClickHouseJoin.JoinType.CROSS); + select.setJoinClauses(Arrays.asList(join)); + String result = ClickHouseVisitor.asString(select); + String answer = "SELECT t1.a1, t2.a2, t1.b1, t2.b2 FROM t1 JOIN t2"; + assertEquals(answer, result); + } + + @Test + void selectCrossJoinAliasedTest() { + List indexes = Collections.emptyList(); + ClickHouseSchema.ClickHouseColumn a1_col = new ClickHouseSchema.ClickHouseColumn("a1", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, null); + ClickHouseSchema.ClickHouseColumn b1_col = new ClickHouseSchema.ClickHouseColumn("b1", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, null); + ClickHouseSchema.ClickHouseColumn a2_col = new ClickHouseSchema.ClickHouseColumn("a2", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, null); + ClickHouseSchema.ClickHouseColumn b2_col = new ClickHouseSchema.ClickHouseColumn("b2", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, null); + ClickHouseSchema.ClickHouseTable table1 = new ClickHouseSchema.ClickHouseTable("t1", + Arrays.asList(a1_col, b1_col), indexes, false); + ClickHouseSchema.ClickHouseTable table2 = new ClickHouseSchema.ClickHouseTable("t2", + Arrays.asList(a2_col, b2_col), indexes, false); + a1_col.setTable(table1); + b1_col.setTable(table1); + a2_col.setTable(table2); + b2_col.setTable(table2); + + ClickHouseTableReference table1_ref = new ClickHouseTableReference(table1, "left"); + ClickHouseTableReference table2_ref = new ClickHouseTableReference(table2, "right"); + + List t1_col_ref = table1_ref.getColumnReferences(); + ClickHouseColumnReference a1_ref = t1_col_ref.get(0); + ClickHouseColumnReference a2_ref = t1_col_ref.get(1); + + List t2_col_ref = table2_ref.getColumnReferences(); + ClickHouseColumnReference b1_ref = t2_col_ref.get(0); + ClickHouseColumnReference b2_ref = t2_col_ref.get(1); + + ClickHouseSelect select = new ClickHouseSelect(); + select.setFetchColumns(Arrays.asList(a1_ref, a2_ref, b1_ref, b2_ref)); + select.setFromClause(table1_ref); + ClickHouseExpression.ClickHouseJoin join = new ClickHouseExpression.ClickHouseJoin(table1_ref, table2_ref, + ClickHouseExpression.ClickHouseJoin.JoinType.CROSS); + select.setJoinClauses(Arrays.asList(join)); + String result = ClickHouseVisitor.asString(select); + String answer = "SELECT left.a1, left.b1, right.a2, right.b2 FROM t1 AS left JOIN t2 AS right"; + assertEquals(answer, result); + } + + @Test + void selectJoinONTest() { + List indexes = Collections.emptyList(); + ClickHouseSchema.ClickHouseColumn a1_col = new ClickHouseSchema.ClickHouseColumn("a1", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, null); + ClickHouseSchema.ClickHouseColumn b1_col = new ClickHouseSchema.ClickHouseColumn("b1", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, null); + ClickHouseSchema.ClickHouseColumn a2_col = new ClickHouseSchema.ClickHouseColumn("a2", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, null); + ClickHouseSchema.ClickHouseColumn b2_col = new ClickHouseSchema.ClickHouseColumn("b2", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, null); + ClickHouseSchema.ClickHouseTable table1 = new ClickHouseSchema.ClickHouseTable("t1", + Arrays.asList(a1_col, b1_col), indexes, false); + ClickHouseSchema.ClickHouseTable table2 = new ClickHouseSchema.ClickHouseTable("t2", + Arrays.asList(a2_col, b2_col), indexes, false); + a1_col.setTable(table1); + b1_col.setTable(table1); + a2_col.setTable(table2); + b2_col.setTable(table2); + + ClickHouseTableReference table1_ref = new ClickHouseTableReference(table1, null); + ClickHouseTableReference table2_ref = new ClickHouseTableReference(table2, null); + + List t1_col_ref = table1_ref.getColumnReferences(); + ClickHouseColumnReference a1_ref = t1_col_ref.get(0); + ClickHouseColumnReference b1_ref = t1_col_ref.get(1); + + List t2_col_ref = table2_ref.getColumnReferences(); + ClickHouseColumnReference a2_ref = t2_col_ref.get(0); + ClickHouseColumnReference b2_ref = t2_col_ref.get(1); + + ClickHouseSelect select = new ClickHouseSelect(); + select.setFetchColumns(Arrays.asList(a1_ref, a2_ref, b1_ref, b2_ref)); + select.setFromClause(table1_ref); + ClickHouseExpression.ClickHouseJoinOnClause on = new ClickHouseExpression.ClickHouseJoinOnClause(a1_ref, + a2_ref); + ClickHouseExpression.ClickHouseJoin join = new ClickHouseExpression.ClickHouseJoin(table1_ref, table2_ref, + ClickHouseExpression.ClickHouseJoin.JoinType.INNER, on); + select.setJoinClauses(Arrays.asList(join)); + String result = ClickHouseVisitor.asString(select); + String answer = "SELECT t1.a1, t2.a2, t1.b1, t2.b2 FROM t1 INNER JOIN t2 ON ((t1.a1)=(t2.a2))"; + assertEquals(answer, result); + } + + @Test + void selectJoinONAliasedTest() { + List indexes = Collections.emptyList(); + ClickHouseSchema.ClickHouseColumn a1_col = new ClickHouseSchema.ClickHouseColumn("a1", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, null); + ClickHouseSchema.ClickHouseColumn b1_col = new ClickHouseSchema.ClickHouseColumn("b1", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, null); + ClickHouseSchema.ClickHouseColumn a2_col = new ClickHouseSchema.ClickHouseColumn("a2", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, null); + ClickHouseSchema.ClickHouseColumn b2_col = new ClickHouseSchema.ClickHouseColumn("b2", + ClickHouseSchema.ClickHouseLancerDataType.getRandom(), false, false, null); + ClickHouseSchema.ClickHouseTable table1 = new ClickHouseSchema.ClickHouseTable("t1", + Arrays.asList(a1_col, b1_col), indexes, false); + ClickHouseSchema.ClickHouseTable table2 = new ClickHouseSchema.ClickHouseTable("t2", + Arrays.asList(a2_col, b2_col), indexes, false); + a1_col.setTable(table1); + b1_col.setTable(table1); + a2_col.setTable(table2); + b2_col.setTable(table2); + + ClickHouseTableReference table1_ref = new ClickHouseTableReference(table1, "left"); + ClickHouseTableReference table2_ref = new ClickHouseTableReference(table2, "right"); + + List t1_col_ref = table1_ref.getColumnReferences(); + ClickHouseColumnReference a1_ref = t1_col_ref.get(0); + ClickHouseColumnReference b1_ref = t1_col_ref.get(1); + + List t2_col_ref = table2_ref.getColumnReferences(); + ClickHouseColumnReference a2_ref = t2_col_ref.get(0); + ClickHouseColumnReference b2_ref = t2_col_ref.get(1); + + ClickHouseSelect select = new ClickHouseSelect(); + select.setFetchColumns(Arrays.asList(a1_ref, a2_ref, b1_ref, b2_ref)); + select.setFromClause(table1_ref); + ClickHouseExpression.ClickHouseJoinOnClause on = new ClickHouseExpression.ClickHouseJoinOnClause(a1_ref, + a2_ref); + ClickHouseExpression.ClickHouseJoin join = new ClickHouseExpression.ClickHouseJoin(table1_ref, table2_ref, + ClickHouseExpression.ClickHouseJoin.JoinType.INNER, on); + select.setJoinClauses(Arrays.asList(join)); + String result = ClickHouseVisitor.asString(select); + String answer = "SELECT left.a1, right.a2, left.b1, right.b2 FROM t1 AS left INNER JOIN t2 AS right ON ((left.a1)=(right.a2))"; + assertEquals(answer, result); + } +} diff --git a/test/sqlancer/common/query/SQLQueryErrorTest.java b/test/sqlancer/common/query/SQLQueryErrorTest.java new file mode 100644 index 000000000..dea1e7a53 --- /dev/null +++ b/test/sqlancer/common/query/SQLQueryErrorTest.java @@ -0,0 +1,92 @@ +package sqlancer.common.query; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class SQLQueryErrorTest { + @Test + public void testSettersAndGetters() { + SQLQueryError error = new SQLQueryError(); + error.setLevel(SQLQueryError.ErrorLevel.ERROR); + error.setCode(123); + error.setMessage("Test message"); + assertEquals(SQLQueryError.ErrorLevel.ERROR, error.getLevel()); + assertEquals(123, error.getCode()); + assertEquals("Test message", error.getMessage()); + } + + @Test + public void testHasSameLevel() { + SQLQueryError e1 = new SQLQueryError(); + SQLQueryError e2 = new SQLQueryError(); + e1.setLevel(SQLQueryError.ErrorLevel.WARNING); + e2.setLevel(SQLQueryError.ErrorLevel.WARNING); + assertTrue(e1.hasSameLevel(e2)); + e2.setLevel(SQLQueryError.ErrorLevel.ERROR); + assertFalse(e1.hasSameLevel(e2)); + } + + @Test + public void testHasSameCodeAndMessage() { + SQLQueryError e1 = new SQLQueryError(); + SQLQueryError e2 = new SQLQueryError(); + e1.setCode(1); + e2.setCode(1); + e1.setMessage("msg"); + e2.setMessage("msg"); + assertTrue(e1.hasSameCodeAndMessage(e2)); + e2.setCode(2); + assertFalse(e1.hasSameCodeAndMessage(e2)); + e2.setCode(1); + e2.setMessage("other"); + assertFalse(e1.hasSameCodeAndMessage(e2)); + } + + @Test + public void testEquals() { + SQLQueryError e1 = new SQLQueryError(); + SQLQueryError e2 = new SQLQueryError(); + e1.setLevel(SQLQueryError.ErrorLevel.ERROR); + e1.setCode(1); + e1.setMessage("msg"); + e2.setLevel(SQLQueryError.ErrorLevel.ERROR); + e2.setCode(1); + e2.setMessage("msg"); + assertEquals(e1, e2); + e2.setLevel(SQLQueryError.ErrorLevel.WARNING); + assertNotEquals(e1, e2); + } + + @Test + public void testToString() { + SQLQueryError e = new SQLQueryError(); + e.setLevel(SQLQueryError.ErrorLevel.ERROR); + e.setCode(1); + e.setMessage("msg"); + String str = e.toString(); + assertTrue(str.contains("Level: ERROR")); + assertTrue(str.contains("Code: 1")); + assertTrue(str.contains("Message: msg")); + } + + @Test + public void testCompareTo() { + SQLQueryError e1 = new SQLQueryError(); + SQLQueryError e2 = new SQLQueryError(); + e1.setCode(1); + e2.setCode(2); + assertTrue(e1.compareTo(e2) < 0); + e2.setCode(1); + e1.setLevel(SQLQueryError.ErrorLevel.ERROR); + e2.setLevel(SQLQueryError.ErrorLevel.WARNING); + assertTrue(e1.compareTo(e2) > 0 || e1.compareTo(e2) < 0); + e2.setLevel(SQLQueryError.ErrorLevel.ERROR); + e1.setMessage("a"); + e2.setMessage("b"); + assertTrue(e1.compareTo(e2) < 0); + } +} diff --git a/test/sqlancer/dbms/TestClickHouse.java b/test/sqlancer/dbms/TestClickHouse.java index 2cf53eef2..87d7713fe 100644 --- a/test/sqlancer/dbms/TestClickHouse.java +++ b/test/sqlancer/dbms/TestClickHouse.java @@ -11,62 +11,128 @@ public class TestClickHouse { @Test public void testClickHouseTLPWhereGroupBy() { - String clickHouseAvailable = System.getenv("CLICKHOUSE_AVAILABLE"); - boolean clickHouseIsAvailable = clickHouseAvailable != null && clickHouseAvailable.equalsIgnoreCase("true"); - assumeTrue(clickHouseIsAvailable); + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.CLICKHOUSE_ENV)); assertEquals(0, Main.executeMain("--timeout-seconds", "60", "--num-queries", TestConfig.NUM_QUERIES, "--num-threads", - "5", "--username", "default", "--password", "", "clickhouse", "--oracle", "TLPWhere", - "--oracle", "TLPGroupBy")); + "5", "--username", "default", "--password", "", "--database-prefix", "T1_", "clickhouse", + "--oracle", "TLPWhere", "--oracle", "TLPGroupBy")); } @Test public void testClickHouseTLPWhere() { - String clickHouseAvailable = System.getenv("CLICKHOUSE_AVAILABLE"); - boolean clickHouseIsAvailable = clickHouseAvailable != null && clickHouseAvailable.equalsIgnoreCase("true"); - assumeTrue(clickHouseIsAvailable); - assertEquals(0, Main.executeMain("--timeout-seconds", "60", "--num-queries", TestConfig.NUM_QUERIES, - "--num-threads", "5", "--username", "default", "--password", "", "clickhouse", "--oracle", "TLPWhere")); + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.CLICKHOUSE_ENV)); + assertEquals(0, + Main.executeMain("--timeout-seconds", "60", "--num-queries", TestConfig.NUM_QUERIES, "--num-threads", + "5", "--username", "default", "--password", "", "--database-prefix", "T2_", "clickhouse", + "--oracle", "TLPWhere")); } @Test public void testClickHouseTLPHaving() { - String clickHouseAvailable = System.getenv("CLICKHOUSE_AVAILABLE"); - boolean clickHouseIsAvailable = clickHouseAvailable != null && clickHouseAvailable.equalsIgnoreCase("true"); - assumeTrue(clickHouseIsAvailable); + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.CLICKHOUSE_ENV)); assertEquals(0, - Main.executeMain("--timeout-seconds", "60", "--num-queries", TestConfig.NUM_QUERIES, "--num-threads", - "5", "--username", "default", "--password", "", "clickhouse", "--oracle", "TLPHaving")); + Main.executeMain("--log-each-select", "true", "--print-failed", "false", "--timeout-seconds", "60", + "--num-queries", TestConfig.NUM_QUERIES, "--num-threads", "1", "--username", "default", + "--password", "", "--database-prefix", "T3_", "clickhouse", "--oracle", "TLPHaving")); } @Test public void testClickHouseTLPGroupBy() { - String clickHouseAvailable = System.getenv("CLICKHOUSE_AVAILABLE"); - boolean clickHouseIsAvailable = clickHouseAvailable != null && clickHouseAvailable.equalsIgnoreCase("true"); - assumeTrue(clickHouseIsAvailable); + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.CLICKHOUSE_ENV)); assertEquals(0, - Main.executeMain("--timeout-seconds", "60", "--num-queries", TestConfig.NUM_QUERIES, "--num-threads", - "5", "--username", "default", "--password", "", "clickhouse", "--oracle", "TLPGroupBy")); + Main.executeMain("--log-each-select", "true", "--print-failed", "false", "--timeout-seconds", "60", + "--num-queries", TestConfig.NUM_QUERIES, "--num-threads", "5", "--username", "default", + "--password", "", "--database-prefix", "T4_", "clickhouse", "--oracle", "TLPGroupBy")); } @Test public void testClickHouseTLPDistinct() { - String clickHouseAvailable = System.getenv("CLICKHOUSE_AVAILABLE"); - boolean clickHouseIsAvailable = clickHouseAvailable != null && clickHouseAvailable.equalsIgnoreCase("true"); - assumeTrue(clickHouseIsAvailable); + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.CLICKHOUSE_ENV)); assertEquals(0, - Main.executeMain("--timeout-seconds", "60", "--num-queries", TestConfig.NUM_QUERIES, "--num-threads", - "5", "--username", "default", "--password", "", "clickhouse", "--oracle", "TLPDistinct")); + Main.executeMain("--log-each-select", "true", "--print-failed", "false", "--timeout-seconds", "60", + "--num-queries", TestConfig.NUM_QUERIES, "--num-threads", "5", "--username", "default", + "--password", "", "--database-prefix", "T5_", "clickhouse", "--oracle", "TLPDistinct")); } @Test public void testClickHouseTLPAggregate() { - String clickHouseAvailable = System.getenv("CLICKHOUSE_AVAILABLE"); - boolean clickHouseIsAvailable = clickHouseAvailable != null && clickHouseAvailable.equalsIgnoreCase("true"); - assumeTrue(clickHouseIsAvailable); + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.CLICKHOUSE_ENV)); + assertEquals(0, + Main.executeMain("--log-each-select", "true", "--print-failed", "false", "--timeout-seconds", "60", + "--num-queries", TestConfig.NUM_QUERIES, "--num-threads", "5", "--username", "default", + "--password", "", "--database-prefix", "T6_", "clickhouse", "--oracle", "TLPAggregate")); + } + + @Test + public void testClickHouseNoREC() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.CLICKHOUSE_ENV)); + assertEquals(0, + Main.executeMain("--log-each-select", "true", "--print-failed", "false", "--timeout-seconds", "60", + "--num-queries", TestConfig.NUM_QUERIES, "--num-threads", "1", "--username", "default", + "--password", "", "--database-prefix", "T7_", "clickhouse", "--oracle", "NoREC")); + } + + @Test + public void testClickHouseTLPWhereGroupByWithJoins() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.CLICKHOUSE_ENV)); assertEquals(0, Main.executeMain("--timeout-seconds", "60", "--num-queries", TestConfig.NUM_QUERIES, "--num-threads", - "5", "--username", "default", "--password", "", "clickhouse", "--oracle", "TLPAggregate")); + "5", "--username", "default", "--password", "", "--database-prefix", "T8_", "clickhouse", + "--oracle", "TLPWhere", "--oracle", "TLPGroupBy")); + } + + @Test + public void testClickHouseTLPWhereWithJoins() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.CLICKHOUSE_ENV)); + assertEquals(0, + Main.executeMain("--timeout-seconds", "60", "--num-queries", TestConfig.NUM_QUERIES, "--num-threads", + "5", "--username", "default", "--password", "", "--database-prefix", "T9_", "clickhouse", + "--oracle", "TLPWhere")); + } + + @Test + public void testClickHouseTLPHavingWithJoins() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.CLICKHOUSE_ENV)); + assertEquals(0, + Main.executeMain("--log-each-select", "true", "--print-failed", "false", "--timeout-seconds", "60", + "--num-queries", TestConfig.NUM_QUERIES, "--num-threads", "1", "--username", "default", + "--password", "", "--database-prefix", "T10_", "clickhouse", "--oracle", "TLPHaving")); + } + + @Test + public void testClickHouseTLPGroupByWithJoins() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.CLICKHOUSE_ENV)); + assertEquals(0, + Main.executeMain("--log-each-select", "true", "--print-failed", "false", "--timeout-seconds", "60", + "--num-queries", TestConfig.NUM_QUERIES, "--num-threads", "5", "--username", "default", + "--password", "", "--database-prefix", "T11_", "clickhouse", "--oracle", "TLPGroupBy")); + } + + @Test + public void testClickHouseTLPDistinctWithJoins() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.CLICKHOUSE_ENV)); + assertEquals(0, + Main.executeMain("--log-each-select", "true", "--print-failed", "false", "--timeout-seconds", "60", + "--num-queries", TestConfig.NUM_QUERIES, "--num-threads", "5", "--username", "default", + "--password", "", "--database-prefix", "T12_", "clickhouse", "--oracle", "TLPDistinct")); + } + + @Test + public void testClickHouseTLPAggregateWithJoins() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.CLICKHOUSE_ENV)); + assertEquals(0, + Main.executeMain("--log-each-select", "true", "--print-failed", "false", "--timeout-seconds", "60", + "--num-queries", TestConfig.NUM_QUERIES, "--num-threads", "5", "--username", "default", + "--password", "", "--database-prefix", "T13_", "clickhouse", "--oracle", "TLPAggregate")); + } + + @Test + public void testClickHouseNoRECWithJoins() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.CLICKHOUSE_ENV)); + assertEquals(0, + Main.executeMain("--log-each-select", "true", "--print-failed", "false", "--timeout-seconds", "60", + "--num-queries", TestConfig.NUM_QUERIES, "--num-threads", "1", "--username", "default", + "--password", "", "--database-prefix", "T14_", "clickhouse", "--oracle", "NoREC")); } } diff --git a/test/sqlancer/dbms/TestCnosDBNoREC.java b/test/sqlancer/dbms/TestCnosDBNoREC.java new file mode 100644 index 000000000..1a89a972a --- /dev/null +++ b/test/sqlancer/dbms/TestCnosDBNoREC.java @@ -0,0 +1,22 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestCnosDBNoREC { + + @Test + public void testCnosDBNoREC() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.CNOSDB_ENV)); + // Run with 0 queries as current implementation is resulting in database crashes + assertEquals(0, + Main.executeMain(new String[] { "--host", "127.0.0.1", "--port", "8902", "--username", "root", + "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, "--num-queries", "0", "cnosdb", + "--oracle", "NOREC" })); + } + +} diff --git a/test/sqlancer/dbms/TestCnosDBTLP.java b/test/sqlancer/dbms/TestCnosDBTLP.java new file mode 100644 index 000000000..4b12aa409 --- /dev/null +++ b/test/sqlancer/dbms/TestCnosDBTLP.java @@ -0,0 +1,22 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestCnosDBTLP { + + @Test + public void testCnosDBTLP() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.CNOSDB_ENV)); + // Run with 0 queries as current implementation is resulting in database crashes + assertEquals(0, + Main.executeMain(new String[] { "--host", "127.0.0.1", "--port", "8902", "--username", "root", + "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, "--num-queries", "0", "cnosdb", + "--oracle", "QUERY_PARTITIONING" })); + } + +} diff --git a/test/sqlancer/dbms/TestCockroachDBCERT.java b/test/sqlancer/dbms/TestCockroachDBCERT.java new file mode 100644 index 000000000..4d7140e40 --- /dev/null +++ b/test/sqlancer/dbms/TestCockroachDBCERT.java @@ -0,0 +1,19 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestCockroachDBCERT { + + @Test + public void testCockroachDBCERT() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.COCKROACHDB_ENV)); + assertEquals(0, Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--num-threads", "4", "--num-queries", TestConfig.NUM_QUERIES, "cockroachdb", "--oracle", "CERT" })); + } + +} diff --git a/test/sqlancer/dbms/TestCockroachDB.java b/test/sqlancer/dbms/TestCockroachDBNoREC.java similarity index 58% rename from test/sqlancer/dbms/TestCockroachDB.java rename to test/sqlancer/dbms/TestCockroachDBNoREC.java index 5eb701e3b..bf1a51193 100644 --- a/test/sqlancer/dbms/TestCockroachDB.java +++ b/test/sqlancer/dbms/TestCockroachDBNoREC.java @@ -7,15 +7,13 @@ import sqlancer.Main; -public class TestCockroachDB { +public class TestCockroachDBNoREC { @Test - public void testMySQL() { - String cockroachDB = System.getenv("COCKROACHDB_AVAILABLE"); - boolean cockroachDBIsAvailable = cockroachDB != null && cockroachDB.equalsIgnoreCase("true"); - assumeTrue(cockroachDBIsAvailable); + public void testCockroachDBNoREC() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.COCKROACHDB_ENV)); assertEquals(0, Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, - "--num-queries", TestConfig.NUM_QUERIES, "cockroachdb" })); + "--num-queries", TestConfig.NUM_QUERIES, "cockroachdb", "--oracle", "NOREC" })); } } diff --git a/test/sqlancer/dbms/TestCockroachDBTLP.java b/test/sqlancer/dbms/TestCockroachDBTLP.java new file mode 100644 index 000000000..f916844ed --- /dev/null +++ b/test/sqlancer/dbms/TestCockroachDBTLP.java @@ -0,0 +1,19 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestCockroachDBTLP { + + @Test + public void testCockroachDBTLP() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.COCKROACHDB_ENV)); + assertEquals(0, Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--num-queries", TestConfig.NUM_QUERIES, "cockroachdb", "--oracle", "QUERY_PARTITIONING" })); + } + +} diff --git a/test/sqlancer/dbms/TestConfig.java b/test/sqlancer/dbms/TestConfig.java index 9fac06920..f6be45648 100644 --- a/test/sqlancer/dbms/TestConfig.java +++ b/test/sqlancer/dbms/TestConfig.java @@ -3,4 +3,26 @@ public class TestConfig { public static final String NUM_QUERIES = "1000"; public static final String SECONDS = "300"; + + public static final String CLICKHOUSE_ENV = "CLICKHOUSE_AVAILABLE"; + public static final String CNOSDB_ENV = "CNOSDB_AVAILABLE"; + public static final String COCKROACHDB_ENV = "COCKROACHDB_AVAILABLE"; + public static final String DATABEND_ENV = "DATABEND_AVAILABLE"; + public static final String DATAFUSION_ENV = "DATAFUSION_AVAILABLE"; + public static final String DORIS_ENV = "DORIS_AVAILABLE"; + public static final String HIVE_ENV = "HIVE_AVAILABLE"; + public static final String SPARK_ENV = "SPARK_AVAILABLE"; + public static final String MARIADB_ENV = "MARIADB_AVAILABLE"; + public static final String MATERIALIZE_ENV = "MATERIALIZE_AVAILABLE"; + public static final String MYSQL_ENV = "MYSQL_AVAILABLE"; + public static final String OCEANBASE_ENV = "OCEANBASE_AVAILABLE"; + public static final String POSTGRES_ENV = "POSTGRES_AVAILABLE"; + public static final String PRESTO_ENV = "PRESTO_AVAILABLE"; + public static final String TIDB_ENV = "TIDB_AVAILABLE"; + public static final String YUGABYTE_ENV = "YUGABYTE_AVAILABLE"; + + public static boolean isEnvironmentTrue(String key) { + String value = System.getenv(key); + return value != null && value.equalsIgnoreCase("true"); + } } diff --git a/test/sqlancer/dbms/TestDataFusion.java b/test/sqlancer/dbms/TestDataFusion.java new file mode 100644 index 000000000..b2b5e2a1a --- /dev/null +++ b/test/sqlancer/dbms/TestDataFusion.java @@ -0,0 +1,19 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestDataFusion { + @Test + public void testDataFusion() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.DATAFUSION_ENV)); + + assertEquals(0, Main.executeMain("--random-seed", "0", "--num-threads", "1", // TODO(datafusion) update when + // multithread is supported + "--timeout-seconds", TestConfig.SECONDS, "--num-queries", TestConfig.NUM_QUERIES, "datafusion")); + } +} diff --git a/test/sqlancer/dbms/TestDatabend.java b/test/sqlancer/dbms/TestDatabend.java deleted file mode 100644 index a9575fd7d..000000000 --- a/test/sqlancer/dbms/TestDatabend.java +++ /dev/null @@ -1,107 +0,0 @@ -package sqlancer.dbms; - -import org.junit.jupiter.api.Test; -import sqlancer.Main; -import sqlancer.Randomly; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assumptions.assumeTrue; - -public class TestDatabend { - - @Test - public void testDatabendNoREC() { - String databendAvailable = System.getenv("DATABEND_AVAILABLE"); - boolean databendIsAvailable = databendAvailable != null && databendAvailable.equalsIgnoreCase("true"); - assumeTrue(databendIsAvailable); - assertEquals(0, - Main.executeMain("--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, "--num-threads", "4", - "--num-queries", TestConfig.NUM_QUERIES, "--database-prefix", "databend", - "--random-string-generation", String.valueOf(Randomly.StringGenerationStrategy.NUMERIC), - "--host", "127.0.0.1", "--port", "3307", "databend", "--oracle", "NOREC")); - } - - // TODO Databend待修复的bug(union schema error mismatch)https://github.com/datafuselabs/databend/issues/7463 - @Test - public void testDatabendTLPQueryPartitioning() { - String databendAvailable = System.getenv("DATABEND_AVAILABLE"); - boolean databendIsAvailable = databendAvailable != null && databendAvailable.equalsIgnoreCase("true"); - assumeTrue(databendIsAvailable); - assertEquals(0, - Main.executeMain("--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, "--num-threads", "4", - "--num-queries", TestConfig.NUM_QUERIES, "--database-prefix", "databend", - "--random-string-generation", String.valueOf(Randomly.StringGenerationStrategy.NUMERIC), - "--host", "127.0.0.1", "--port", "3307", "databend", "--oracle", "QUERY_PARTITIONING")); - } - - // @Test - // public void testDatabendTLPWhere() { - // String databendAvailable = System.getenv("DATABEND_AVAILABLE"); - // boolean databendIsAvailable = databendAvailable != null && databendAvailable.equalsIgnoreCase("true"); - // assumeTrue(databendIsAvailable); - // assertEquals(0, - // Main.executeMain("--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, "--num-threads", "4", - // "--num-queries", TestConfig.NUM_QUERIES, "--database-prefix", "databend", - // "--random-string-generation", String.valueOf(Randomly.StringGenerationStrategy.NUMERIC), - // "--host", "127.0.0.1", "--port", "3307", "databend", "--oracle", "WHERE")); - // } - // - // @Test - // public void testDatabendTLPGroupBy() { - // String databendAvailable = System.getenv("DATABEND_AVAILABLE"); - // boolean databendIsAvailable = databendAvailable != null && databendAvailable.equalsIgnoreCase("true"); - // assumeTrue(databendIsAvailable); - // assertEquals(0, - // Main.executeMain("--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, "--num-threads", "4", - // "--num-queries", TestConfig.NUM_QUERIES, "--database-prefix", "databend", - // "--random-string-generation", String.valueOf(Randomly.StringGenerationStrategy.NUMERIC), - // "--host", "127.0.0.1", "--port", "3307", "databend", "--oracle", "GROUP_BY")); - // } - // - // @Test - // public void testDatabendTLPHaving() { - // String databendAvailable = System.getenv("DATABEND_AVAILABLE"); - // boolean databendIsAvailable = databendAvailable != null && databendAvailable.equalsIgnoreCase("true"); - // assumeTrue(databendIsAvailable); - // assertEquals(0, - // Main.executeMain("--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, "--num-threads", "4", - // "--num-queries", TestConfig.NUM_QUERIES, "--database-prefix", "databend", - // "--random-string-generation", String.valueOf(Randomly.StringGenerationStrategy.NUMERIC), - // "--host", "127.0.0.1", "--port", "3307", "databend", "--oracle", "HAVING")); - // } - // - // @Test - // public void testDatabendTLPDistinct() { - // String databendAvailable = System.getenv("DATABEND_AVAILABLE"); - // boolean databendIsAvailable = databendAvailable != null && databendAvailable.equalsIgnoreCase("true"); - // assumeTrue(databendIsAvailable); - // assertEquals(0, - // Main.executeMain("--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, "--num-threads", "4", - // "--num-queries", TestConfig.NUM_QUERIES, "--database-prefix", "databend", - // "--random-string-generation", String.valueOf(Randomly.StringGenerationStrategy.NUMERIC), - // "--host", "127.0.0.1", "--port", "3307", "databend", "--oracle", "DISTINCT")); - // } - // - // @Test - // public void testDatabendTLPAggregate() { - // String databendAvailable = System.getenv("DATABEND_AVAILABLE"); - // boolean databendIsAvailable = databendAvailable != null && databendAvailable.equalsIgnoreCase("true"); - // assumeTrue(databendIsAvailable); - // assertEquals(0, - // Main.executeMain("--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, "--num-threads", "4", - // "--num-queries", TestConfig.NUM_QUERIES, "--database-prefix", "databend", - // "--random-string-generation", String.valueOf(Randomly.StringGenerationStrategy.NUMERIC), - // "--host", "127.0.0.1", "--port", "3307", "databend", "--oracle", "AGGREGATE")); - // } - - // @Test - // void testConnection() { - // assertEquals(0, - // Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, - // "--num-threads", "4", "--num-queries", TestConfig.NUM_QUERIES, "--database-prefix", "databend", - // "--random-string-generation", String.valueOf(Randomly.StringGenerationStrategy.NUMERIC), - // "--host", "127.0.0.1", "--port", "3307", "--username", "user1", "--password", "1234", - // "databend", "--oracle", "HAVING" })); - // } - -} diff --git a/test/sqlancer/dbms/TestDatabendNoREC.java b/test/sqlancer/dbms/TestDatabendNoREC.java new file mode 100644 index 000000000..679f8c161 --- /dev/null +++ b/test/sqlancer/dbms/TestDatabendNoREC.java @@ -0,0 +1,23 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; +import sqlancer.Randomly; + +public class TestDatabendNoREC { + + @Test + public void testDatabendNoREC() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.DATABEND_ENV)); + assertEquals(0, + Main.executeMain("--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, "--num-threads", "4", + "--num-queries", TestConfig.NUM_QUERIES, "--database-prefix", "databend", + "--random-string-generation", String.valueOf(Randomly.StringGenerationStrategy.ALPHANUMERIC), + "--host", "127.0.0.1", "--port", "3307", "databend", "--oracle", "NOREC")); + } + +} diff --git a/test/sqlancer/dbms/TestDatabendPQS.java b/test/sqlancer/dbms/TestDatabendPQS.java new file mode 100644 index 000000000..fba733d44 --- /dev/null +++ b/test/sqlancer/dbms/TestDatabendPQS.java @@ -0,0 +1,23 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; +import sqlancer.Randomly; + +public class TestDatabendPQS { + + @Test + public void testDatabendPQS() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.DATABEND_ENV)); + assertEquals(0, + Main.executeMain("--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, "--num-threads", "4", + "--num-queries", TestConfig.NUM_QUERIES, "--database-prefix", "databend", + "--random-string-generation", String.valueOf(Randomly.StringGenerationStrategy.ALPHANUMERIC), + "--host", "127.0.0.1", "--port", "3307", "databend", "--oracle", "PQS")); + } + +} diff --git a/test/sqlancer/dbms/TestDatabendTLP.java b/test/sqlancer/dbms/TestDatabendTLP.java new file mode 100644 index 000000000..27ba53416 --- /dev/null +++ b/test/sqlancer/dbms/TestDatabendTLP.java @@ -0,0 +1,21 @@ +package sqlancer.dbms; + +import org.junit.jupiter.api.Test; +import sqlancer.Main; +import sqlancer.Randomly; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +public class TestDatabendTLP { + + @Test + public void testDatabendTLPQueryPartitioning() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.DATABEND_ENV)); + assertEquals(0, + Main.executeMain("--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, "--num-threads", "4", + "--num-queries", TestConfig.NUM_QUERIES, "--database-prefix", "databend", + "--random-string-generation", String.valueOf(Randomly.StringGenerationStrategy.ALPHANUMERIC), + "--host", "127.0.0.1", "--port", "3307", "databend", "--oracle", "QUERY_PARTITIONING")); + } +} diff --git a/test/sqlancer/dbms/TestDorisNoREC.java b/test/sqlancer/dbms/TestDorisNoREC.java new file mode 100644 index 000000000..76a9d50a8 --- /dev/null +++ b/test/sqlancer/dbms/TestDorisNoREC.java @@ -0,0 +1,26 @@ +package sqlancer.dbms; + +import org.junit.jupiter.api.Test; +import sqlancer.Main; +import sqlancer.Randomly; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +public class TestDorisNoREC { + private final String host = "127.0.0.1"; + private final String port = "9030"; + private final String username = "sqlancer"; + private final String password = "sqlancer"; + + @Test + public void testdorisNoREC() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.DORIS_ENV)); + assertEquals(0, + Main.executeMain("--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, "--num-threads", "4", + "--num-queries", TestConfig.NUM_QUERIES, "--database-prefix", "doris", + "--random-string-generation", String.valueOf(Randomly.StringGenerationStrategy.ALPHANUMERIC), + "--username", username, "--password", password, "--host", host, "--port", port, "doris", + "--oracle", "NOREC")); + } +} diff --git a/test/sqlancer/dbms/TestDorisPQS.java b/test/sqlancer/dbms/TestDorisPQS.java new file mode 100644 index 000000000..5760003d3 --- /dev/null +++ b/test/sqlancer/dbms/TestDorisPQS.java @@ -0,0 +1,26 @@ +package sqlancer.dbms; + +import org.junit.jupiter.api.Test; +import sqlancer.Main; +import sqlancer.Randomly; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +public class TestDorisPQS { + private final String host = "127.0.0.1"; + private final String port = "9030"; + private final String username = "sqlancer"; + private final String password = "sqlancer"; + + @Test + public void testdorisPQS() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.DORIS_ENV)); + assertEquals(0, + Main.executeMain("--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, "--num-threads", "4", + "--num-queries", TestConfig.NUM_QUERIES, "--database-prefix", "doris", + "--random-string-generation", String.valueOf(Randomly.StringGenerationStrategy.ALPHANUMERIC), + "--username", username, "--password", password, "--host", host, "--port", port, "doris", + "--oracle", "PQS")); + } +} diff --git a/test/sqlancer/dbms/TestDorisTLP.java b/test/sqlancer/dbms/TestDorisTLP.java new file mode 100644 index 000000000..f8bea9e47 --- /dev/null +++ b/test/sqlancer/dbms/TestDorisTLP.java @@ -0,0 +1,26 @@ +package sqlancer.dbms; + +import org.junit.jupiter.api.Test; +import sqlancer.Main; +import sqlancer.Randomly; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +public class TestDorisTLP { + private final String host = "127.0.0.1"; + private final String port = "9030"; + private final String username = "sqlancer"; + private final String password = "sqlancer"; + + @Test + public void testdorisTLP() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.DORIS_ENV)); + assertEquals(0, + Main.executeMain("--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, "--num-threads", "4", + "--num-queries", TestConfig.NUM_QUERIES, "--database-prefix", "doris", + "--random-string-generation", String.valueOf(Randomly.StringGenerationStrategy.ALPHANUMERIC), + "--username", username, "--password", password, "--host", host, "--port", port, "doris", + "--oracle", "QUERY_PARTITIONING")); + } +} diff --git a/test/sqlancer/dbms/TestSQLite3.java b/test/sqlancer/dbms/TestDuckDBNoREC.java similarity index 68% rename from test/sqlancer/dbms/TestSQLite3.java rename to test/sqlancer/dbms/TestDuckDBNoREC.java index f26b449c1..dea1c70dd 100644 --- a/test/sqlancer/dbms/TestSQLite3.java +++ b/test/sqlancer/dbms/TestDuckDBNoREC.java @@ -6,13 +6,11 @@ import sqlancer.Main; -public class TestSQLite3 { - +public class TestDuckDBNoREC { @Test - public void testSqlite() { + public void testDuckDBNoREC() { // run with one thread due to multithreading issues, see https://github.com/sqlancer/sqlancer/pull/45 assertEquals(0, Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, - "--num-threads", "1", "--num-queries", "0", "sqlite3" })); + "--num-threads", "1", "--num-queries", TestConfig.NUM_QUERIES, "duckdb", "--oracle", "NOREC" })); } - } diff --git a/test/sqlancer/dbms/TestDuckDB.java b/test/sqlancer/dbms/TestDuckDBTLP.java similarity index 90% rename from test/sqlancer/dbms/TestDuckDB.java rename to test/sqlancer/dbms/TestDuckDBTLP.java index d36d8610f..322e3eeb4 100644 --- a/test/sqlancer/dbms/TestDuckDB.java +++ b/test/sqlancer/dbms/TestDuckDBTLP.java @@ -6,10 +6,10 @@ import sqlancer.Main; -public class TestDuckDB { +public class TestDuckDBTLP { @Test - public void testDuckDB() { + public void testDuckDBTLP() { // run with one thread due to multithreading issues, see https://github.com/sqlancer/sqlancer/pull/45 assertEquals(0, Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, diff --git a/test/sqlancer/dbms/TestH2.java b/test/sqlancer/dbms/TestH2.java index 8f13170ff..ec6962101 100644 --- a/test/sqlancer/dbms/TestH2.java +++ b/test/sqlancer/dbms/TestH2.java @@ -1,7 +1,6 @@ package sqlancer.dbms; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assumptions.assumeTrue; import org.junit.jupiter.api.Test; @@ -10,10 +9,7 @@ public class TestH2 { @Test - public void testDuckDB() { - String h2Available = System.getenv("H2_AVAILABLE"); - boolean mariaDBIsAvailable = h2Available != null && h2Available.equalsIgnoreCase("true"); - assumeTrue(mariaDBIsAvailable); + public void testH2DB() { assertEquals(0, Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, "--num-threads", "4", "--num-queries", TestConfig.NUM_QUERIES, "h2" })); diff --git a/test/sqlancer/dbms/TestHSQLDBNoREC.java b/test/sqlancer/dbms/TestHSQLDBNoREC.java new file mode 100644 index 000000000..721fbb126 --- /dev/null +++ b/test/sqlancer/dbms/TestHSQLDBNoREC.java @@ -0,0 +1,15 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestHSQLDBNoREC { + @Test + public void testHSQLDBNoREC() { + assertEquals(0, Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--num-threads", "1", "--num-queries", TestConfig.NUM_QUERIES, "hsqldb", "--oracle", "NOREC" })); + } +} diff --git a/test/sqlancer/dbms/TestHSQLDBTLP.java b/test/sqlancer/dbms/TestHSQLDBTLP.java new file mode 100644 index 000000000..a0b9c18d9 --- /dev/null +++ b/test/sqlancer/dbms/TestHSQLDBTLP.java @@ -0,0 +1,15 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestHSQLDBTLP { + @Test + public void testHSQLDBTLP() { + assertEquals(0, Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--num-threads", "1", "--num-queries", TestConfig.NUM_QUERIES, "hsqldb", "--oracle", "WHERE" })); + } +} diff --git a/test/sqlancer/dbms/TestHiveTLP.java b/test/sqlancer/dbms/TestHiveTLP.java new file mode 100644 index 000000000..5f8692f3f --- /dev/null +++ b/test/sqlancer/dbms/TestHiveTLP.java @@ -0,0 +1,20 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestHiveTLP { + + @Test + public void testHiveTLPWhere() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.HIVE_ENV)); + assertEquals(0, + Main.executeMain(new String[] { "--canonicalize-sql-strings", "false", "--random-seed", "0", + "--timeout-seconds", TestConfig.SECONDS, "--num-threads", "1", "--num-queries", + TestConfig.NUM_QUERIES, "hive", "--oracle", "TLPWhere" })); + } +} \ No newline at end of file diff --git a/test/sqlancer/dbms/TestMariaDB.java b/test/sqlancer/dbms/TestMariaDB.java index e5c188d28..eb26bfb4a 100644 --- a/test/sqlancer/dbms/TestMariaDB.java +++ b/test/sqlancer/dbms/TestMariaDB.java @@ -3,19 +3,45 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assumptions.assumeTrue; +import java.util.List; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import sqlancer.Main; +import sqlancer.mariadb.MariaDBSchema; +import sqlancer.mariadb.ast.MariaDBColumnName; +import sqlancer.mariadb.ast.MariaDBSelectStatement; +import sqlancer.mariadb.ast.MariaDBTableReference; +import sqlancer.mariadb.ast.MariaDBVisitor; public class TestMariaDB { + @Test + public void testSelectAsString() { + MariaDBSchema.MariaDBColumn c0 = new MariaDBSchema.MariaDBColumn("c0", MariaDBSchema.MariaDBDataType.INT, true, + 0); + MariaDBSchema.MariaDBColumn c1 = new MariaDBSchema.MariaDBColumn("c1", MariaDBSchema.MariaDBDataType.INT, false, + 0); + List columns = List.of(c0, c1); + List indices = List.of(); + MariaDBSchema.MariaDBTable t1 = new MariaDBSchema.MariaDBTable("t1", columns, indices, + MariaDBSchema.MariaDBTable.MariaDBEngine.INNO_DB); + MariaDBSchema.MariaDBTables tables = new MariaDBSchema.MariaDBTables(List.of(t1)); + + MariaDBSelectStatement select = new MariaDBSelectStatement(); + select.setFetchColumns(tables.getColumns().stream().map(MariaDBColumnName::new).collect(Collectors.toList())); + select.setFromList(tables.getTables().stream().map(MariaDBTableReference::new).collect(Collectors.toList())); + + String selectString = MariaDBVisitor.asString(select); + assertEquals("SELECT c0, c1 FROM t1", selectString); + } + @Test public void testMariaDB() { - String mariaDBAvailable = System.getenv("MARIADB_AVAILABLE"); - boolean mariaDBIsAvailable = mariaDBAvailable != null && mariaDBAvailable.equalsIgnoreCase("true"); - assumeTrue(mariaDBIsAvailable); + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.MARIADB_ENV)); assertEquals(0, Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, - "--num-queries", "0", "mariadb" })); + "--num-queries", TestConfig.NUM_QUERIES, "mariadb" })); } } diff --git a/test/sqlancer/dbms/TestMaterializeNoREC.java b/test/sqlancer/dbms/TestMaterializeNoREC.java new file mode 100644 index 000000000..e4ac980ce --- /dev/null +++ b/test/sqlancer/dbms/TestMaterializeNoREC.java @@ -0,0 +1,21 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestMaterializeNoREC { + + @Test + public void test() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.MATERIALIZE_ENV)); + assertEquals(0, + Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--num-threads", "4", "--num-queries", TestConfig.NUM_QUERIES, "--username", "materialize", + "materialize", "--oracle", "NOREC", "--set-max-tables-mvs", "true" })); + } + +} diff --git a/test/sqlancer/dbms/TestMaterializePQS.java b/test/sqlancer/dbms/TestMaterializePQS.java new file mode 100644 index 000000000..0c6be974d --- /dev/null +++ b/test/sqlancer/dbms/TestMaterializePQS.java @@ -0,0 +1,22 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestMaterializePQS { + + @Test + public void test() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.MATERIALIZE_ENV)); + assertEquals(0, + Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--num-threads", "4", "--num-queries", TestConfig.NUM_QUERIES, "--random-string-generation", + "ALPHANUMERIC_SPECIALCHAR", "--username", "materialize", "materialize", "--oracle", "pqs", + "--set-max-tables-mvs", "true" })); + } + +} diff --git a/test/sqlancer/dbms/TestMaterializeTLP.java b/test/sqlancer/dbms/TestMaterializeTLP.java new file mode 100644 index 000000000..a98b3e053 --- /dev/null +++ b/test/sqlancer/dbms/TestMaterializeTLP.java @@ -0,0 +1,21 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestMaterializeTLP { + + @Test + public void test() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.MATERIALIZE_ENV)); + assertEquals(0, + Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--num-threads", "4", "--num-queries", TestConfig.NUM_QUERIES, "--username", "materialize", + "materialize", "--set-max-tables-mvs", "true", "--oracle", "QUERY_PARTITIONING" })); + } + +} diff --git a/test/sqlancer/dbms/TestMySQLCERT.java b/test/sqlancer/dbms/TestMySQLCERT.java new file mode 100644 index 000000000..a7704a545 --- /dev/null +++ b/test/sqlancer/dbms/TestMySQLCERT.java @@ -0,0 +1,22 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestMySQLCERT { + + @Test + public void testMySQL() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.MYSQL_ENV)); + // Run with 0 queries as there are false positives for every mutation + assertEquals(0, + Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--max-expression-depth", "1", "--num-threads", "1", "--num-queries", "0", "mysql", "--oracle", + "CERT" })); + } + +} diff --git a/test/sqlancer/dbms/TestMySQLDQE.java b/test/sqlancer/dbms/TestMySQLDQE.java new file mode 100644 index 000000000..c137c1d0a --- /dev/null +++ b/test/sqlancer/dbms/TestMySQLDQE.java @@ -0,0 +1,21 @@ +package sqlancer.dbms; + +import org.junit.jupiter.api.Test; +import sqlancer.Main; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +public class TestMySQLDQE { + + @Test + public void testMySQL() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.MYSQL_ENV)); + // Run with 0 queries as there are false positives for every mutation + assertEquals(0, + Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--max-expression-depth", "1", "--num-threads", "4", "--num-queries", TestConfig.NUM_QUERIES, + "mysql", "--oracle", "DQE" })); + } + +} diff --git a/test/sqlancer/dbms/TestMySQLPQS.java b/test/sqlancer/dbms/TestMySQLPQS.java index ba6a9c501..6f1b0786f 100644 --- a/test/sqlancer/dbms/TestMySQLPQS.java +++ b/test/sqlancer/dbms/TestMySQLPQS.java @@ -9,12 +9,9 @@ public class TestMySQLPQS { - String mysqlAvailable = System.getenv("MYSQL_AVAILABLE"); - boolean mysqlIsAvailable = mysqlAvailable != null && mysqlAvailable.equalsIgnoreCase("true"); - @Test public void testPQS() { - assumeTrue(mysqlIsAvailable); + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.MYSQL_ENV)); assertEquals(0, Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, "--num-threads", "4", "--random-string-generation", "ALPHANUMERIC", "--database-prefix", diff --git a/test/sqlancer/dbms/TestMySQLTLP.java b/test/sqlancer/dbms/TestMySQLTLP.java index 3d3e7d107..6254c58f3 100644 --- a/test/sqlancer/dbms/TestMySQLTLP.java +++ b/test/sqlancer/dbms/TestMySQLTLP.java @@ -9,12 +9,9 @@ public class TestMySQLTLP { - String mysqlAvailable = System.getenv("MYSQL_AVAILABLE"); - boolean mysqlIsAvailable = mysqlAvailable != null && mysqlAvailable.equalsIgnoreCase("true"); - @Test public void testMySQL() { - assumeTrue(mysqlIsAvailable); + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.MYSQL_ENV)); assertEquals(0, Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, "--max-expression-depth", "1", "--num-threads", "4", "--num-queries", TestConfig.NUM_QUERIES, diff --git a/test/sqlancer/dbms/TestOceanBaseNoREC.java b/test/sqlancer/dbms/TestOceanBaseNoREC.java index 9c650265c..c687590de 100644 --- a/test/sqlancer/dbms/TestOceanBaseNoREC.java +++ b/test/sqlancer/dbms/TestOceanBaseNoREC.java @@ -9,12 +9,9 @@ public class TestOceanBaseNoREC { - String oceanBaseAvailable = System.getenv("OCEANBASE_AVAILABLE"); - boolean oceanBaseIsAvailable = oceanBaseAvailable != null && oceanBaseAvailable.equalsIgnoreCase("true"); - @Test public void testNoREC() { - assumeTrue(oceanBaseIsAvailable); + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.OCEANBASE_ENV)); assertEquals(0, Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, "--num-threads", "4", "--database-prefix", "norecdb", "--num-queries", TestConfig.NUM_QUERIES, diff --git a/test/sqlancer/dbms/TestOceanBasePQS.java b/test/sqlancer/dbms/TestOceanBasePQS.java index adb2007d5..404977398 100644 --- a/test/sqlancer/dbms/TestOceanBasePQS.java +++ b/test/sqlancer/dbms/TestOceanBasePQS.java @@ -9,12 +9,9 @@ public class TestOceanBasePQS { - String oceanBaseAvailable = System.getenv("OCEANBASE_AVAILABLE"); - boolean oceanBaseIsAvailable = oceanBaseAvailable != null && oceanBaseAvailable.equalsIgnoreCase("true"); - @Test public void testPQS() { - assumeTrue(oceanBaseIsAvailable); + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.OCEANBASE_ENV)); assertEquals(0, Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, "--num-threads", "4", "--random-string-generation", "ALPHANUMERIC_SPECIALCHAR", diff --git a/test/sqlancer/dbms/TestOceanBaseTLP.java b/test/sqlancer/dbms/TestOceanBaseTLP.java index fe84118e6..cf86c2078 100644 --- a/test/sqlancer/dbms/TestOceanBaseTLP.java +++ b/test/sqlancer/dbms/TestOceanBaseTLP.java @@ -9,12 +9,9 @@ public class TestOceanBaseTLP { - String oceanBaseAvailable = System.getenv("OCEANBASE_AVAILABLE"); - boolean oceanBaseIsAvailable = oceanBaseAvailable != null && oceanBaseAvailable.equalsIgnoreCase("true"); - @Test public void testTLP() { - assumeTrue(oceanBaseIsAvailable); + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.OCEANBASE_ENV)); assertEquals(0, Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, "--num-threads", "4", "--database-prefix", "tlpdb", "--num-queries", TestConfig.NUM_QUERIES, diff --git a/test/sqlancer/dbms/TestPostgres.java b/test/sqlancer/dbms/TestPostgres.java deleted file mode 100644 index 5eeb1f858..000000000 --- a/test/sqlancer/dbms/TestPostgres.java +++ /dev/null @@ -1,35 +0,0 @@ -package sqlancer.dbms; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assumptions.assumeTrue; - -import org.junit.jupiter.api.Test; - -import sqlancer.Main; - -public class TestPostgres { - - String postgresAvailable = System.getenv("POSTGRES_AVAILABLE"); - boolean postgresIsAvailable = postgresAvailable != null && postgresAvailable.equalsIgnoreCase("true"); - - @Test - public void testPostgres() { - assumeTrue(postgresIsAvailable); - assertEquals(0, - Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, - "--num-threads", "4", "--num-queries", TestConfig.NUM_QUERIES, "postgres", "--test-collations", - "false" })); - } - - @Test - public void testPQS() { - assumeTrue(postgresIsAvailable); - assertEquals(0, - Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, - "--num-threads", "4", "--num-queries", TestConfig.NUM_QUERIES, "--random-string-generation", - "ALPHANUMERIC_SPECIALCHAR", "--database-prefix", - "pqsdb" /* Workaround for connections not being closed */, "postgres", "--test-collations", - "false", "--oracle", "pqs" })); - } - -} diff --git a/test/sqlancer/dbms/TestPostgresCERT.java b/test/sqlancer/dbms/TestPostgresCERT.java new file mode 100644 index 000000000..afa5e3801 --- /dev/null +++ b/test/sqlancer/dbms/TestPostgresCERT.java @@ -0,0 +1,20 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestPostgresCERT { + + @Test + public void testCERT() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.POSTGRES_ENV)); + assertEquals(0, + Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--num-threads", "4", "--num-queries", TestConfig.NUM_QUERIES, "postgres", "--test-collations", + "false", "--oracle", "CERT" })); + } +} diff --git a/test/sqlancer/dbms/TestPostgresNoREC.java b/test/sqlancer/dbms/TestPostgresNoREC.java new file mode 100644 index 000000000..8b9f00d48 --- /dev/null +++ b/test/sqlancer/dbms/TestPostgresNoREC.java @@ -0,0 +1,20 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestPostgresNoREC { + + @Test + public void testNoREC() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.POSTGRES_ENV)); + assertEquals(0, + Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--num-threads", "4", "--num-queries", TestConfig.NUM_QUERIES, "postgres", "--test-collations", + "false", "--oracle", "NOREC" })); + } +} diff --git a/test/sqlancer/dbms/TestPostgresPQS.java b/test/sqlancer/dbms/TestPostgresPQS.java new file mode 100644 index 000000000..f6a37f533 --- /dev/null +++ b/test/sqlancer/dbms/TestPostgresPQS.java @@ -0,0 +1,21 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestPostgresPQS { + + @Test + public void testPQS() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.POSTGRES_ENV)); + assertEquals(0, + Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--num-threads", "4", "--num-queries", TestConfig.NUM_QUERIES, "--random-string-generation", + "ALPHANUMERIC_SPECIALCHAR", "postgres", "--test-collations", "false", "--oracle", "pqs" })); + } + +} diff --git a/test/sqlancer/dbms/TestPostgresTLP.java b/test/sqlancer/dbms/TestPostgresTLP.java new file mode 100644 index 000000000..5bd722991 --- /dev/null +++ b/test/sqlancer/dbms/TestPostgresTLP.java @@ -0,0 +1,20 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestPostgresTLP { + + @Test + public void testTLP() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.POSTGRES_ENV)); + assertEquals(0, + Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--num-threads", "4", "--num-queries", TestConfig.NUM_QUERIES, "postgres", "--test-collations", + "false" })); + } +} diff --git a/test/sqlancer/dbms/TestPrestoNoREC.java b/test/sqlancer/dbms/TestPrestoNoREC.java new file mode 100644 index 000000000..e38e90483 --- /dev/null +++ b/test/sqlancer/dbms/TestPrestoNoREC.java @@ -0,0 +1,17 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestPrestoNoREC { + @Test + public void testPrestoNoREC() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.PRESTO_ENV)); + assertEquals(0, Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--num-threads", "4", "--num-queries", TestConfig.NUM_QUERIES, "presto", "--oracle", "NOREC" })); + } +} diff --git a/test/sqlancer/dbms/TestPrestoTLP.java b/test/sqlancer/dbms/TestPrestoTLP.java new file mode 100644 index 000000000..6ffefc333 --- /dev/null +++ b/test/sqlancer/dbms/TestPrestoTLP.java @@ -0,0 +1,19 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestPrestoTLP { + @Test + public void testPrestoTLP() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.PRESTO_ENV)); + assertEquals(0, + Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--num-threads", "4", "--num-queries", TestConfig.NUM_QUERIES, "--validate-result-size-only", + "true", "--canonicalize-sql-strings", "false", "presto", "--oracle", "QUERY_PARTITIONING" })); + } +} diff --git a/test/sqlancer/dbms/TestSQLiteCODDTest.java b/test/sqlancer/dbms/TestSQLiteCODDTest.java new file mode 100644 index 000000000..c1948a9b8 --- /dev/null +++ b/test/sqlancer/dbms/TestSQLiteCODDTest.java @@ -0,0 +1,16 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestSQLiteCODDTest { + + @Test + public void testSqliteCODDTest() { + assertEquals(0, Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--num-threads", "1", "--num-queries", TestConfig.NUM_QUERIES, "sqlite3", "--oracle", "CODDTest" })); + } +} diff --git a/test/sqlancer/dbms/TestSQLiteNoREC.java b/test/sqlancer/dbms/TestSQLiteNoREC.java new file mode 100644 index 000000000..aa7741659 --- /dev/null +++ b/test/sqlancer/dbms/TestSQLiteNoREC.java @@ -0,0 +1,17 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestSQLiteNoREC { + + @Test + public void testSqliteNoREC() { + assertEquals(0, Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--num-threads", "1", "--num-queries", TestConfig.NUM_QUERIES, "sqlite3", "--oracle", "NoREC" })); + } + +} diff --git a/test/sqlancer/dbms/TestSQLiteTLP.java b/test/sqlancer/dbms/TestSQLiteTLP.java new file mode 100644 index 000000000..7d90fd93c --- /dev/null +++ b/test/sqlancer/dbms/TestSQLiteTLP.java @@ -0,0 +1,19 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestSQLiteTLP { + + @Test + public void testSqliteTLP() { + // run with one thread due to multithreading issues, see https://github.com/sqlancer/sqlancer/pull/45 + assertEquals(0, + Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--num-threads", "1", "--num-queries", TestConfig.NUM_QUERIES, "sqlite3", "--oracle", + "QUERY_PARTITIONING" })); + } +} diff --git a/test/sqlancer/dbms/TestSparkTLP.java b/test/sqlancer/dbms/TestSparkTLP.java new file mode 100644 index 000000000..83302ceff --- /dev/null +++ b/test/sqlancer/dbms/TestSparkTLP.java @@ -0,0 +1,20 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestSparkTLP { + + @Test + public void testSparkTLPWhere() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.SPARK_ENV)); + assertEquals(0, + Main.executeMain(new String[] { "--canonicalize-sql-strings", "false", "--random-seed", "0", + "--timeout-seconds", TestConfig.SECONDS, "--num-threads", "1", "--num-queries", + TestConfig.NUM_QUERIES, "spark", "--oracle", "TLPWhere" })); + } +} \ No newline at end of file diff --git a/test/sqlancer/dbms/TestTiDBCERT.java b/test/sqlancer/dbms/TestTiDBCERT.java new file mode 100644 index 000000000..748444870 --- /dev/null +++ b/test/sqlancer/dbms/TestTiDBCERT.java @@ -0,0 +1,19 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestTiDBCERT { + + @Test + public void testCERT() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.TIDB_ENV)); + assertEquals(0, Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--num-queries", "4", "tidb", "--oracle", "CERT" })); + } + +} diff --git a/test/sqlancer/dbms/TestTiDBTLP.java b/test/sqlancer/dbms/TestTiDBTLP.java new file mode 100644 index 000000000..11cf7581e --- /dev/null +++ b/test/sqlancer/dbms/TestTiDBTLP.java @@ -0,0 +1,19 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestTiDBTLP { + + @Test + public void testTLP() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.TIDB_ENV)); + assertEquals(0, Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--num-queries", "0", "tidb" })); + } + +} diff --git a/test/sqlancer/dbms/TestYCQL.java b/test/sqlancer/dbms/TestYCQL.java new file mode 100644 index 000000000..1198e2376 --- /dev/null +++ b/test/sqlancer/dbms/TestYCQL.java @@ -0,0 +1,18 @@ +package sqlancer.dbms; + +import org.junit.jupiter.api.Test; +import sqlancer.Main; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +public class TestYCQL { + @Test + public void testYCQL() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.YUGABYTE_ENV)); + assertEquals(0, + Main.executeMain("--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, "--username", + "cassandra", "--password", "cassandra", "--num-threads", "1", "--num-queries", + TestConfig.NUM_QUERIES, "ycql")); + } +} diff --git a/test/sqlancer/dbms/TestYSQLNoREC.java b/test/sqlancer/dbms/TestYSQLNoREC.java new file mode 100644 index 000000000..51f0487cc --- /dev/null +++ b/test/sqlancer/dbms/TestYSQLNoREC.java @@ -0,0 +1,19 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestYSQLNoREC { + @Test + public void testYSQLNoREC() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.YUGABYTE_ENV)); + assertEquals(0, + Main.executeMain("--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, "--username", + "yugabyte", "--password", "yugabyte", "--num-threads", "1", "--num-queries", + TestConfig.NUM_QUERIES, "ysql", "--oracle", "NOREC")); + } +} diff --git a/test/sqlancer/dbms/TestYSQLPQS.java b/test/sqlancer/dbms/TestYSQLPQS.java new file mode 100644 index 000000000..6a21b6cb2 --- /dev/null +++ b/test/sqlancer/dbms/TestYSQLPQS.java @@ -0,0 +1,19 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestYSQLPQS { + @Test + public void testYSQLPQS() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.YUGABYTE_ENV)); + assertEquals(0, + Main.executeMain("--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, "--username", + "yugabyte", "--password", "yugabyte", "--num-threads", "1", "--num-queries", + TestConfig.NUM_QUERIES, "ysql", "--oracle", "PQS")); + } +} diff --git a/test/sqlancer/dbms/TestYSQLTLP.java b/test/sqlancer/dbms/TestYSQLTLP.java new file mode 100644 index 000000000..fdcdf372a --- /dev/null +++ b/test/sqlancer/dbms/TestYSQLTLP.java @@ -0,0 +1,19 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestYSQLTLP { + @Test + public void testYSQLTLP() { + assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.YUGABYTE_ENV)); + assertEquals(0, + Main.executeMain("--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, "--username", + "yugabyte", "--password", "yugabyte", "--num-threads", "1", "--num-queries", + TestConfig.NUM_QUERIES, "ysql", "--oracle", "QUERY_PARTITIONING")); + } +} diff --git a/test/sqlancer/dqp/mariadb/TestMariaDBDQP.java b/test/sqlancer/dqp/mariadb/TestMariaDBDQP.java new file mode 100644 index 000000000..47dabc225 --- /dev/null +++ b/test/sqlancer/dqp/mariadb/TestMariaDBDQP.java @@ -0,0 +1,22 @@ +package sqlancer.dqp.mariadb; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; +import sqlancer.dbms.TestConfig; + +public class TestMariaDBDQP { + + @Test + public void testMariaDBDQPMethod() { + String mariadb = System.getenv("MARIADB_AVAILABLE"); + boolean mariadbIsAvailable = mariadb != null && mariadb.equalsIgnoreCase("true"); + assumeTrue(mariadbIsAvailable); + assertEquals(0, Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--num-threads", "1", "--num-queries", TestConfig.NUM_QUERIES, "mariadb", "--oracle", "DQP" })); + } + +} diff --git a/test/sqlancer/dqp/mysql/TestMySQLDQP.java b/test/sqlancer/dqp/mysql/TestMySQLDQP.java new file mode 100644 index 000000000..38e5eef61 --- /dev/null +++ b/test/sqlancer/dqp/mysql/TestMySQLDQP.java @@ -0,0 +1,22 @@ +package sqlancer.dqp.mysql; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; +import sqlancer.dbms.TestConfig; + +public class TestMySQLDQP { + + @Test + public void testmysqlQPG() { + String mysql = System.getenv("MYSQL_AVAILABLE"); + boolean mysqlIsAvailable = mysql != null && mysql.equalsIgnoreCase("true"); + assumeTrue(mysqlIsAvailable); + assertEquals(0, Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--num-threads", "1", "--num-queries", TestConfig.NUM_QUERIES, "mysql", "--oracle", "DQP" })); + } + +} diff --git a/test/sqlancer/dbms/TestTiDB.java b/test/sqlancer/dqp/tidb/TestTiDBDQP.java similarity index 68% rename from test/sqlancer/dbms/TestTiDB.java rename to test/sqlancer/dqp/tidb/TestTiDBDQP.java index b97e59273..300dfde34 100644 --- a/test/sqlancer/dbms/TestTiDB.java +++ b/test/sqlancer/dqp/tidb/TestTiDBDQP.java @@ -1,4 +1,4 @@ -package sqlancer.dbms; +package sqlancer.dqp.tidb; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assumptions.assumeTrue; @@ -6,16 +6,17 @@ import org.junit.jupiter.api.Test; import sqlancer.Main; +import sqlancer.dbms.TestConfig; -public class TestTiDB { +public class TestTiDBDQP { @Test - public void testMySQL() { + public void testTiDBQPG() { String tiDB = System.getenv("TIDB_AVAILABLE"); boolean tiDBIsAvailable = tiDB != null && tiDB.equalsIgnoreCase("true"); assumeTrue(tiDBIsAvailable); assertEquals(0, Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, - "--num-queries", "0", "tidb" })); + "--num-threads", "1", "--num-queries", TestConfig.NUM_QUERIES, "tidb", "--oracle", "DQP" })); } } diff --git a/test/sqlancer/mysql/MySQLToStringVisitorTest.java b/test/sqlancer/mysql/MySQLToStringVisitorTest.java new file mode 100644 index 000000000..3f991e3f1 --- /dev/null +++ b/test/sqlancer/mysql/MySQLToStringVisitorTest.java @@ -0,0 +1,77 @@ +package sqlancer.mysql; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.List; + +import org.junit.jupiter.api.Test; + +import sqlancer.mysql.ast.MySQLAggregate; +import sqlancer.mysql.ast.MySQLCaseOperator; +import sqlancer.mysql.ast.MySQLColumnReference; +import sqlancer.mysql.ast.MySQLConstant; +import sqlancer.mysql.ast.MySQLExpression; +import sqlancer.mysql.ast.MySQLConstant.MySQLIntConstant; + +public class MySQLToStringVisitorTest { + + @Test + void visitAggregateToString() { + MySQLSchema.MySQLColumn aCol = new MySQLSchema.MySQLColumn("a", MySQLSchema.MySQLDataType.INT, false, 0); + MySQLColumnReference aRef = new MySQLColumnReference(aCol, MySQLConstant.createNullConstant()); + + MySQLAggregate aggrCount = new MySQLAggregate(List.of(aRef), MySQLAggregate.MySQLAggregateFunction.COUNT); + assertEquals("COUNT(a)", MySQLVisitor.asString(aggrCount)); + + MySQLAggregate aggrSum = new MySQLAggregate(List.of(aRef), MySQLAggregate.MySQLAggregateFunction.SUM); + assertEquals("SUM(a)", MySQLVisitor.asString(aggrSum)); + + MySQLAggregate aggrMin = new MySQLAggregate(List.of(aRef), MySQLAggregate.MySQLAggregateFunction.MIN); + assertEquals("MIN(a)", MySQLVisitor.asString(aggrMin)); + + MySQLAggregate aggrMax = new MySQLAggregate(List.of(aRef), MySQLAggregate.MySQLAggregateFunction.MAX); + assertEquals("MAX(a)", MySQLVisitor.asString(aggrMax)); + } + + @Test + void visitAggregateWithDistinctToString() { + MySQLSchema.MySQLColumn aCol = new MySQLSchema.MySQLColumn("a", MySQLSchema.MySQLDataType.INT, false, 0); + MySQLColumnReference aRef = new MySQLColumnReference(aCol, MySQLConstant.createNullConstant()); + + MySQLAggregate aggrCountDistinct = new MySQLAggregate(List.of(aRef), + MySQLAggregate.MySQLAggregateFunction.COUNT_DISTINCT); + assertEquals("COUNT(DISTINCT a)", MySQLVisitor.asString(aggrCountDistinct)); + + MySQLAggregate aggrSumDistinct = new MySQLAggregate(List.of(aRef), + MySQLAggregate.MySQLAggregateFunction.SUM_DISTINCT); + assertEquals("SUM(DISTINCT a)", MySQLVisitor.asString(aggrSumDistinct)); + + MySQLAggregate aggrMinDistinct = new MySQLAggregate(List.of(aRef), + MySQLAggregate.MySQLAggregateFunction.MIN_DISTINCT); + assertEquals("MIN(DISTINCT a)", MySQLVisitor.asString(aggrMinDistinct)); + + MySQLAggregate aggrMaxDistinct = new MySQLAggregate(List.of(aRef), + MySQLAggregate.MySQLAggregateFunction.MAX_DISTINCT); + assertEquals("MAX(DISTINCT a)", MySQLVisitor.asString(aggrMaxDistinct)); + } + + @Test + void visitCaseWhenToString() { + MySQLSchema.MySQLColumn aCol = new MySQLSchema.MySQLColumn("a", MySQLSchema.MySQLDataType.INT, false, 0); + MySQLColumnReference switchExpr = new MySQLColumnReference(aCol, MySQLConstant.createNullConstant()); + List whenExprs = List.of(MySQLIntConstant.createIntConstant(1), + MySQLIntConstant.createIntConstant(2)); + List thenExprs = List.of(MySQLIntConstant.createIntConstant(11), + MySQLIntConstant.createIntConstant(22)); + MySQLConstant elseExpr = MySQLConstant.createIntConstant(0); + + assertEquals("(CASE a WHEN 1 THEN 11 WHEN 2 THEN 22 ELSE 0 END)", + MySQLVisitor.asString(new MySQLCaseOperator(switchExpr, whenExprs, thenExprs, elseExpr))); + assertEquals("(CASE WHEN 1 THEN 11 WHEN 2 THEN 22 ELSE 0 END)", + MySQLVisitor.asString(new MySQLCaseOperator(null, whenExprs, thenExprs, elseExpr))); + assertEquals("(CASE a WHEN 1 THEN 11 WHEN 2 THEN 22 END)", + MySQLVisitor.asString(new MySQLCaseOperator(switchExpr, whenExprs, thenExprs, null))); + assertEquals("(CASE WHEN 1 THEN 11 WHEN 2 THEN 22 END)", + MySQLVisitor.asString(new MySQLCaseOperator(null, whenExprs, thenExprs, null))); + } +} diff --git a/test/sqlancer/mysql/ast/MySQLCaseOperatorTest.java b/test/sqlancer/mysql/ast/MySQLCaseOperatorTest.java new file mode 100644 index 000000000..674ab027e --- /dev/null +++ b/test/sqlancer/mysql/ast/MySQLCaseOperatorTest.java @@ -0,0 +1,67 @@ +package sqlancer.mysql.ast; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; + +import org.junit.jupiter.api.Test; + +import sqlancer.mysql.MySQLSchema; +import sqlancer.mysql.ast.MySQLConstant.MySQLIntConstant; + +public class MySQLCaseOperatorTest { + + @Test + void getExpectedValue_switchConditionMatchesWhen_ReturnsThen() { + MySQLSchema.MySQLColumn aCol = new MySQLSchema.MySQLColumn("a", MySQLSchema.MySQLDataType.INT, false, 0); + MySQLColumnReference switchExpr = new MySQLColumnReference(aCol, MySQLIntConstant.createIntConstant(1)); + List whenExprs = List.of(MySQLIntConstant.createIntConstant(1), + MySQLIntConstant.createIntConstant(2)); + List thenExprs = List.of(MySQLIntConstant.createIntConstant(11), + MySQLIntConstant.createIntConstant(22)); + MySQLConstant elseExpr = MySQLConstant.createIntConstant(0); + + MySQLCaseOperator caseOperator = new MySQLCaseOperator(switchExpr, whenExprs, thenExprs, elseExpr); + + assertEquals(11, caseOperator.getExpectedValue().getInt()); + } + + @Test + void getExpectedValue_switchConditionHasNoMatches_ReturnsElse() { + MySQLSchema.MySQLColumn aCol = new MySQLSchema.MySQLColumn("a", MySQLSchema.MySQLDataType.INT, false, 0); + MySQLColumnReference switchExpr = new MySQLColumnReference(aCol, MySQLIntConstant.createNullConstant()); + List whenExprs = List.of(MySQLIntConstant.createIntConstant(1), + MySQLIntConstant.createIntConstant(2)); + List thenExprs = List.of(MySQLIntConstant.createIntConstant(11), + MySQLIntConstant.createIntConstant(22)); + MySQLConstant elseExpr = MySQLConstant.createIntConstant(0); + + assertEquals(0, new MySQLCaseOperator(switchExpr, whenExprs, thenExprs, elseExpr).getExpectedValue().getInt()); + assertTrue(new MySQLCaseOperator(switchExpr, whenExprs, thenExprs, null).getExpectedValue().isNull()); + } + + @Test + void getExpectedValue_whenTrue_ReturnsThen() { + List whenExprs = List.of(MySQLIntConstant.createIntConstant(1), + MySQLIntConstant.createIntConstant(2)); + List thenExprs = List.of(MySQLIntConstant.createIntConstant(11), + MySQLIntConstant.createIntConstant(22)); + MySQLConstant elseExpr = MySQLConstant.createIntConstant(0); + MySQLCaseOperator caseOperator = new MySQLCaseOperator(null, whenExprs, thenExprs, elseExpr); + + assertEquals(11, caseOperator.getExpectedValue().getInt()); + } + + @Test + void getExpectedValue_whenAllFalse_ReturnsElse() { + List whenExprs = List.of(MySQLIntConstant.createBoolean(false), + MySQLIntConstant.createBoolean(false)); + List thenExprs = List.of(MySQLIntConstant.createIntConstant(11), + MySQLIntConstant.createIntConstant(22)); + MySQLConstant elseExpr = MySQLConstant.createIntConstant(0); + + assertEquals(0, new MySQLCaseOperator(null, whenExprs, thenExprs, elseExpr).getExpectedValue().getInt()); + assertTrue(new MySQLCaseOperator(null, whenExprs, thenExprs, null).getExpectedValue().isNull()); + } +} diff --git a/test/sqlancer/qpg/cockroachdb/TestCockroachDBQPG.java b/test/sqlancer/qpg/cockroachdb/TestCockroachDBQPG.java new file mode 100644 index 000000000..b3ba57a10 --- /dev/null +++ b/test/sqlancer/qpg/cockroachdb/TestCockroachDBQPG.java @@ -0,0 +1,24 @@ +package sqlancer.qpg.cockroachdb; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; +import sqlancer.dbms.TestConfig; + +public class TestCockroachDBQPG { + + @Test + public void testCockroachDBQPG() { + String cockroachDB = System.getenv("COCKROACHDB_AVAILABLE"); + boolean cockroachDBIsAvailable = cockroachDB != null && cockroachDB.equalsIgnoreCase("true"); + assumeTrue(cockroachDBIsAvailable); + assertEquals(0, + Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--num-threads", "1", "--qpg-enable", "true", "--num-queries", TestConfig.NUM_QUERIES, + "cockroachdb", "--oracle", "QUERY_PARTITIONING" })); + } + +} diff --git a/test/sqlancer/qpg/cockroachdb/TestCockroachDBQueryPlan.java b/test/sqlancer/qpg/cockroachdb/TestCockroachDBQueryPlan.java new file mode 100644 index 000000000..c858ee190 --- /dev/null +++ b/test/sqlancer/qpg/cockroachdb/TestCockroachDBQueryPlan.java @@ -0,0 +1,47 @@ +package sqlancer.qpg.cockroachdb; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; +import sqlancer.MainOptions; +import sqlancer.SQLConnection; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.cockroachdb.CockroachDBOptions; +import sqlancer.cockroachdb.CockroachDBProvider; +import sqlancer.cockroachdb.CockroachDBProvider.CockroachDBGlobalState; + +public class TestCockroachDBQueryPlan { + + @Test + void testCockroachDBQueryPlan() throws Exception { + String cockroachDB = System.getenv("COCKROACHDB_AVAILABLE"); + boolean cockroachDBIsAvailable = cockroachDB != null && cockroachDB.equalsIgnoreCase("true"); + assumeTrue(cockroachDBIsAvailable); + + String databaseName = "cockroachdb"; + CockroachDBProvider provider = new CockroachDBProvider(); + CockroachDBGlobalState state = provider.getGlobalStateClass().getDeclaredConstructor().newInstance(); + CockroachDBOptions cockroachdbOption = provider.getOptionClass().getDeclaredConstructor().newInstance(); + state.setDbmsSpecificOptions(cockroachdbOption); + state.setDatabaseName(databaseName); + MainOptions options = new MainOptions(); + state.setMainOptions(options); + state.setState(provider.getStateToReproduce(databaseName)); + SQLConnection con = provider.createDatabase(state); + state.setConnection(con); + Main.StateLogger logger = new Main.StateLogger(databaseName, provider, options); + state.setStateLogger(logger); + + SQLQueryAdapter q = new SQLQueryAdapter("CREATE TABLE t1(a INT, b INT);", true); + q.execute(state); + q = new SQLQueryAdapter("CREATE TABLE t2(c INT);", true); + q.execute(state); + String queryPlan = provider.getQueryPlan("SELECT * FROM t1 RIGHT JOIN t2 ON a<>0;", state); + + assertEquals("left-join (cross);scan t2;select;scan t1;filters;filters (true);", queryPlan); + } + +} diff --git a/test/sqlancer/qpg/materialize/TestMaterializeQPG.java b/test/sqlancer/qpg/materialize/TestMaterializeQPG.java new file mode 100644 index 000000000..431468cee --- /dev/null +++ b/test/sqlancer/qpg/materialize/TestMaterializeQPG.java @@ -0,0 +1,23 @@ +package sqlancer.qpg.materialize; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; +import sqlancer.dbms.TestConfig; + +public class TestMaterializeQPG { + + @Test + public void testMaterializeQPG() { + String materialize = System.getenv("MATERIALIZE_AVAILABLE"); + boolean materializeIsAvailable = materialize != null && materialize.equalsIgnoreCase("true"); + assumeTrue(materializeIsAvailable); + assertEquals(0, Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--num-threads", "4", "--qpg-enable", "true", "--num-queries", TestConfig.NUM_QUERIES, "--username", + "materialize", "materialize", "--oracle", "QUERY_PARTITIONING", "--set-max-tables-mvs", "true" })); + } + +} diff --git a/test/sqlancer/qpg/materialize/TestMaterializeQueryPlan.java b/test/sqlancer/qpg/materialize/TestMaterializeQueryPlan.java new file mode 100644 index 000000000..4d26bc08e --- /dev/null +++ b/test/sqlancer/qpg/materialize/TestMaterializeQueryPlan.java @@ -0,0 +1,49 @@ +package sqlancer.qpg.materialize; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; +import sqlancer.MainOptions; +import sqlancer.SQLConnection; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.materialize.MaterializeOptions; +import sqlancer.materialize.MaterializeGlobalState; +import sqlancer.materialize.MaterializeProvider; + +public class TestMaterializeQueryPlan { + + @Test + void testMaterializeQueryPlan() throws Exception { + String materialize = System.getenv("MATERIALIZE_AVAILABLE"); + boolean materializeIsAvailable = materialize != null && materialize.equalsIgnoreCase("true"); + assumeTrue(materializeIsAvailable); + + String databaseName = "queryplan"; + MaterializeProvider provider = new MaterializeProvider(); + MaterializeGlobalState state = provider.getGlobalStateClass().getDeclaredConstructor().newInstance(); + MaterializeOptions materializeOption = provider.getOptionClass().getDeclaredConstructor().newInstance(); + state.setDbmsSpecificOptions(materializeOption); + state.setDatabaseName(databaseName); + MainOptions options = new MainOptions(); + state.setMainOptions(options); + state.setState(provider.getStateToReproduce(databaseName)); + SQLConnection con = provider.createDatabase(state); + state.setConnection(con); + Main.StateLogger logger = new Main.StateLogger(databaseName, provider, options); + state.setStateLogger(logger); + + SQLQueryAdapter q = new SQLQueryAdapter("CREATE TABLE t1(a INT, b INT);", true); + q.execute(state); + q = new SQLQueryAdapter("CREATE TABLE t2(c INT);", true); + q.execute(state); + String queryPlan = provider.getQueryPlan("SELECT * FROM t1 RIGHT JOIN t2 ON a<>0;", state); + + assertEquals( + "With;ReadStorage queryplan.public.t1 // { arity: 2 };ReadStorage queryplan.public.t2 // { arity: 1 };Return // { arity: 3 };Union // { arity: 3 };Get l0 // { arity: 3 };Project (#2, #3, #0{c}) // { arity: 3 };Union // { arity: 1 };Negate // { arity: 1 };Project (#2{c}) // { arity: 1 };ReadStorage queryplan.public.t2 // { arity: 1 };ReadStorage queryplan.public.t2 // { arity: 1 };;Source queryplan.public.t1;Source queryplan.public.t2;;Target cluster: quickstart;", + queryPlan); + } + +} diff --git a/test/sqlancer/qpg/postgres/TestPostgresQPG.java b/test/sqlancer/qpg/postgres/TestPostgresQPG.java new file mode 100644 index 000000000..20ec3e91e --- /dev/null +++ b/test/sqlancer/qpg/postgres/TestPostgresQPG.java @@ -0,0 +1,22 @@ +package sqlancer.qpg.postgres; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; +import org.junit.jupiter.api.Test; + +import sqlancer.Main; +import sqlancer.dbms.TestConfig; + +public class TestPostgresQPG { + + @Test + public void testPostgresQPG() { + String postgres = System.getenv("POSTGRES_AVAILABLE"); + boolean postgresIsAvailable = postgres != null && postgres.equalsIgnoreCase("true"); + assumeTrue(postgresIsAvailable); + assertEquals(0, + Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--num-threads", "4", "--qpg-enable", "true", "--num-queries", TestConfig.NUM_QUERIES, + "--username", "postgres", "postgres", "--oracle", "NOREC" })); + } +} diff --git a/test/sqlancer/qpg/postgres/TestPostgresQueryPlan.java b/test/sqlancer/qpg/postgres/TestPostgresQueryPlan.java new file mode 100644 index 000000000..c883e87d8 --- /dev/null +++ b/test/sqlancer/qpg/postgres/TestPostgresQueryPlan.java @@ -0,0 +1,165 @@ +package sqlancer.qpg.postgres; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.Test; + +import sqlancer.Main; +import sqlancer.MainOptions; +import sqlancer.SQLConnection; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.postgres.PostgresGlobalState; +import sqlancer.postgres.PostgresOptions; +import sqlancer.postgres.PostgresProvider; + +public class TestPostgresQueryPlan { + + @Test + void testPostgresQueryPlan() throws Exception { + String databaseName = "postgres"; + PostgresProvider provider = new PostgresProvider(); + PostgresGlobalState state = provider.getGlobalStateClass().getDeclaredConstructor().newInstance(); + PostgresOptions postgresOption = provider.getOptionClass().getDeclaredConstructor().newInstance(); + state.setDbmsSpecificOptions(postgresOption); + state.setDatabaseName(databaseName); + MainOptions options = new MainOptions(); + state.setMainOptions(options); + state.setState(provider.getStateToReproduce(databaseName)); + SQLConnection con = provider.createDatabase(state); + state.setConnection(con); + Main.StateLogger logger = new Main.StateLogger(databaseName, provider, options); + state.setStateLogger(logger); + + SQLQueryAdapter q = new SQLQueryAdapter("CREATE TABLE t1(a INT, b INT);", true); + q.execute(state); + q = new SQLQueryAdapter("CREATE TABLE t2(c INT);", true); + q.execute(state); + String queryPlan = provider.getQueryPlan("SELECT * FROM t1 RIGHT JOIN t2 ON a<>0;", state); + assertEquals("Nested Loop Seq Scan Materialize Seq Scan", queryPlan); + } + + @Test + void testFormatQueryPlan() throws Exception { + + PostgresProvider provider = new PostgresProvider(); + + String queryPlan = "[\n" + " {\n" + " \"Plan\": {\n" + " \"Node Type\": \"Aggregate\",\n" + + " \"Strategy\": \"Hashed\",\n" + " \"Partial Mode\": \"Simple\",\n" + + " \"Parallel Aware\": false,\n" + " \"Async Capable\": false,\n" + + " \"Startup Cost\": 62998.82,\n" + " \"Total Cost\": 63009.32,\n" + + " \"Plan Rows\": 1050,\n" + " \"Plan Width\": 4,\n" + " \"Output\": [\"t1.c0\"],\n" + + " \"Group Key\": [\"t1.c0\"],\n" + " \"Planned Partitions\": 0,\n" + " \"Plans\": [\n" + + " {\n" + " \"Node Type\": \"Append\",\n" + + " \"Parent Relationship\": \"Outer\",\n" + " \"Parallel Aware\": false,\n" + + " \"Async Capable\": false,\n" + " \"Startup Cost\": 27150.40,\n" + + " \"Total Cost\": 62996.20,\n" + " \"Plan Rows\": 1050,\n" + + " \"Plan Width\": 4,\n" + " \"Subplans Removed\": 0,\n" + " \"Plans\": [\n" + + " {\n" + " \"Node Type\": \"Group\",\n" + + " \"Parent Relationship\": \"Member\",\n" + " \"Parallel Aware\": false,\n" + + " \"Async Capable\": false,\n" + " \"Startup Cost\": 27150.40,\n" + + " \"Total Cost\": 62949.08,\n" + " \"Plan Rows\": 200,\n" + + " \"Plan Width\": 4,\n" + " \"Output\": [\"t1.c0\"],\n" + + " \"Group Key\": [\"t1.c0\"],\n" + " \"Plans\": [\n" + " {\n" + + " \"Node Type\": \"Gather Merge\",\n" + + " \"Parent Relationship\": \"Outer\",\n" + + " \"Parallel Aware\": false,\n" + " \"Async Capable\": false,\n" + + " \"Startup Cost\": 27150.40,\n" + " \"Total Cost\": 62948.08,\n" + + " \"Plan Rows\": 400,\n" + " \"Plan Width\": 4,\n" + + " \"Output\": [\"t1.c0\"],\n" + " \"Workers Planned\": 2,\n" + + " \"Plans\": [\n" + " {\n" + + " \"Node Type\": \"Group\",\n" + + " \"Parent Relationship\": \"Outer\",\n" + + " \"Parallel Aware\": false,\n" + + " \"Async Capable\": false,\n" + + " \"Startup Cost\": 26150.38,\n" + + " \"Total Cost\": 61901.89,\n" + " \"Plan Rows\": 200,\n" + + " \"Plan Width\": 4,\n" + " \"Output\": [\"t1.c0\"],\n" + + " \"Group Key\": [\"t1.c0\"],\n" + " \"Plans\": [\n" + + " {\n" + " \"Node Type\": \"Merge Join\",\n" + + " \"Parent Relationship\": \"Outer\",\n" + + " \"Parallel Aware\": false,\n" + + " \"Async Capable\": false,\n" + + " \"Join Type\": \"Inner\",\n" + + " \"Startup Cost\": 26150.38,\n" + + " \"Total Cost\": 56906.48,\n" + + " \"Plan Rows\": 1998164,\n" + + " \"Plan Width\": 4,\n" + + " \"Output\": [\"t1.c0\"],\n" + + " \"Inner Unique\": false,\n" + + " \"Merge Cond\": \"(t0.c0 = t1.c0)\",\n" + + " \"Plans\": [\n" + " {\n" + + " \"Node Type\": \"Sort\",\n" + + " \"Parent Relationship\": \"Outer\",\n" + + " \"Parallel Aware\": false,\n" + + " \"Async Capable\": false,\n" + + " \"Startup Cost\": 25970.60,\n" + + " \"Total Cost\": 26362.39,\n" + + " \"Plan Rows\": 156719,\n" + + " \"Plan Width\": 4,\n" + + " \"Output\": [\"t0.c0\"],\n" + + " \"Sort Key\": [\"t0.c0\"],\n" + + " \"Plans\": [\n" + " {\n" + + " \"Node Type\": \"Seq Scan\",\n" + + " \"Parent Relationship\": \"Outer\",\n" + + " \"Parallel Aware\": true,\n" + + " \"Async Capable\": false,\n" + + " \"Relation Name\": \"t0\",\n" + + " \"Schema\": \"public\",\n" + + " \"Alias\": \"t0\",\n" + + " \"Startup Cost\": 0.00,\n" + + " \"Total Cost\": 10301.95,\n" + + " \"Plan Rows\": 156719,\n" + + " \"Plan Width\": 4,\n" + + " \"Output\": [\"t0.c0\"],\n" + + " \"Filter\": \"(t0.c0 < 100)\"\n" + + " }\n" + " ]\n" + + " },\n" + " {\n" + + " \"Node Type\": \"Sort\",\n" + + " \"Parent Relationship\": \"Inner\",\n" + + " \"Parallel Aware\": false,\n" + + " \"Async Capable\": false,\n" + + " \"Startup Cost\": 179.78,\n" + + " \"Total Cost\": 186.16,\n" + + " \"Plan Rows\": 2550,\n" + + " \"Plan Width\": 4,\n" + + " \"Output\": [\"t1.c0\"],\n" + + " \"Sort Key\": [\"t1.c0\"],\n" + + " \"Plans\": [\n" + " {\n" + + " \"Node Type\": \"Seq Scan\",\n" + + " \"Parent Relationship\": \"Outer\",\n" + + " \"Parallel Aware\": false,\n" + + " \"Async Capable\": false,\n" + + " \"Relation Name\": \"t1\",\n" + + " \"Schema\": \"public\",\n" + + " \"Alias\": \"t1\",\n" + + " \"Startup Cost\": 0.00,\n" + + " \"Total Cost\": 35.50,\n" + + " \"Plan Rows\": 2550,\n" + + " \"Plan Width\": 4,\n" + + " \"Output\": [\"t1.c0\"]\n" + " }\n" + + " ]\n" + " }\n" + + " ]\n" + " }\n" + " ]\n" + + " }\n" + " ]\n" + " }\n" + " ]\n" + + " },\n" + " {\n" + " \"Node Type\": \"Bitmap Heap Scan\",\n" + + " \"Parent Relationship\": \"Member\",\n" + " \"Parallel Aware\": false,\n" + + " \"Async Capable\": false,\n" + " \"Relation Name\": \"t2\",\n" + + " \"Schema\": \"public\",\n" + " \"Alias\": \"t2\",\n" + + " \"Startup Cost\": 10.74,\n" + " \"Total Cost\": 31.37,\n" + + " \"Plan Rows\": 850,\n" + " \"Plan Width\": 4,\n" + + " \"Output\": [\"t2.c0\"],\n" + " \"Recheck Cond\": \"(t2.c0 < 10)\",\n" + + " \"Plans\": [\n" + " {\n" + + " \"Node Type\": \"Bitmap Index Scan\",\n" + + " \"Parent Relationship\": \"Outer\",\n" + + " \"Parallel Aware\": false,\n" + " \"Async Capable\": false,\n" + + " \"Index Name\": \"t2_pkey\",\n" + " \"Startup Cost\": 0.00,\n" + + " \"Total Cost\": 10.53,\n" + " \"Plan Rows\": 850,\n" + + " \"Plan Width\": 0,\n" + " \"Index Cond\": \"(t2.c0 < 10)\"\n" + + " }\n" + " ]\n" + " }\n" + " ]\n" + " }\n" + + " ]\n" + " },\n" + " \"Planning Time\": 1.954\n" + " }\n" + "]\n"; + + String formatedQueryPlan = provider.formatQueryPlan(queryPlan); + assertEquals( + "Aggregate Append Group Bitmap Heap Scan Gather Merge Bitmap Index Scan Group Merge Join Sort Sort Seq Scan Seq Scan", + formatedQueryPlan); + } + +} diff --git a/test/sqlancer/qpg/sqlite/TestSQLiteQPG.java b/test/sqlancer/qpg/sqlite/TestSQLiteQPG.java new file mode 100644 index 000000000..488037237 --- /dev/null +++ b/test/sqlancer/qpg/sqlite/TestSQLiteQPG.java @@ -0,0 +1,23 @@ +package sqlancer.qpg.sqlite; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; +import sqlancer.dbms.TestConfig; + +public class TestSQLiteQPG { + + @Test + public void testSqliteQPG() { + assertEquals(0, + Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--num-threads", "1", "--num-queries", TestConfig.NUM_QUERIES, "--random-string-generation", + "ALPHANUMERIC_SPECIALCHAR", "--database-prefix", + "pqsdb" /* Workaround for connections not being closed */, "--qpg-enable", "true", "sqlite3", + "--oracle", "NoREC", "--test-fts", "false", "--test-rtree", "false", "--test-check-constraints", + "false", "--test-in-operator", "false" })); + } + +} diff --git a/test/sqlancer/qpg/sqlite/TestSQLiteQueryPlan.java b/test/sqlancer/qpg/sqlite/TestSQLiteQueryPlan.java new file mode 100644 index 000000000..7672f7edf --- /dev/null +++ b/test/sqlancer/qpg/sqlite/TestSQLiteQueryPlan.java @@ -0,0 +1,41 @@ +package sqlancer.qpg.sqlite; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; +import sqlancer.MainOptions; +import sqlancer.SQLConnection; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.sqlite3.SQLite3GlobalState; +import sqlancer.sqlite3.SQLite3Options; +import sqlancer.sqlite3.SQLite3Provider; + +public class TestSQLiteQueryPlan { + + @Test + void testSQLiteQueryPlan() throws Exception { + String databaseName = "sqlite"; + SQLite3Provider provider = new SQLite3Provider(); + SQLite3GlobalState state = provider.getGlobalStateClass().getDeclaredConstructor().newInstance(); + SQLite3Options sqlite3Option = provider.getOptionClass().getDeclaredConstructor().newInstance(); + state.setDbmsSpecificOptions(sqlite3Option); + state.setDatabaseName(databaseName); + SQLConnection con = provider.createDatabase(state); + state.setConnection(con); + MainOptions options = new MainOptions(); + state.setMainOptions(options); + Main.StateLogger logger = new Main.StateLogger(databaseName, provider, options); + state.setStateLogger(logger); + + SQLQueryAdapter q = new SQLQueryAdapter("CREATE TABLE t1(a INT, b INT);", true); + q.execute(state); + q = new SQLQueryAdapter("CREATE TABLE t2(c INT);", true); + q.execute(state); + String queryPlan = provider.getQueryPlan("SELECT * FROM t1 RIGHT JOIN t2 ON a<>0;", state); + + assertEquals("SCAN t1;SCAN t2;RIGHT-JOIN t2;SCAN t2;", queryPlan); + } + +} diff --git a/test/sqlancer/qpg/tidb/TestTiDBQPG.java b/test/sqlancer/qpg/tidb/TestTiDBQPG.java new file mode 100644 index 000000000..66a43a41a --- /dev/null +++ b/test/sqlancer/qpg/tidb/TestTiDBQPG.java @@ -0,0 +1,24 @@ +package sqlancer.qpg.tidb; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; +import sqlancer.dbms.TestConfig; + +public class TestTiDBQPG { + + @Test + public void testTiDBQPG() { + String tiDB = System.getenv("TIDB_AVAILABLE"); + boolean tiDBIsAvailable = tiDB != null && tiDB.equalsIgnoreCase("true"); + assumeTrue(tiDBIsAvailable); + assertEquals(0, + Main.executeMain(new String[] { "--random-seed", "0", "--timeout-seconds", TestConfig.SECONDS, + "--num-threads", "1", "--qpg-enable", "true", "--num-queries", TestConfig.NUM_QUERIES, "tidb", + "--oracle", "QUERY_PARTITIONING" })); + } + +} diff --git a/test/sqlancer/qpg/tidb/TestTiDBQueryPlan.java b/test/sqlancer/qpg/tidb/TestTiDBQueryPlan.java new file mode 100644 index 000000000..d07bebe61 --- /dev/null +++ b/test/sqlancer/qpg/tidb/TestTiDBQueryPlan.java @@ -0,0 +1,49 @@ +package sqlancer.qpg.tidb; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; +import sqlancer.MainOptions; +import sqlancer.SQLConnection; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.tidb.TiDBOptions; +import sqlancer.tidb.TiDBProvider; +import sqlancer.tidb.TiDBProvider.TiDBGlobalState; + +public class TestTiDBQueryPlan { + + @Test + void testTiDBQueryPlan() throws Exception { + String tiDB = System.getenv("TIDB_AVAILABLE"); + boolean tiDBIsAvailable = tiDB != null && tiDB.equalsIgnoreCase("true"); + assumeTrue(tiDBIsAvailable); + + String databaseName = "tidb"; + TiDBProvider provider = new TiDBProvider(); + TiDBGlobalState state = provider.getGlobalStateClass().getDeclaredConstructor().newInstance(); + TiDBOptions TiDBOption = provider.getOptionClass().getDeclaredConstructor().newInstance(); + state.setDbmsSpecificOptions(TiDBOption); + state.setDatabaseName(databaseName); + MainOptions options = new MainOptions(); + state.setMainOptions(options); + Main.StateLogger logger = new Main.StateLogger(databaseName, provider, options); + state.setStateLogger(logger); + state.setState(provider.getStateToReproduce(databaseName)); + SQLConnection con = provider.createDatabase(state); + state.setConnection(con); + + SQLQueryAdapter q = new SQLQueryAdapter("CREATE TABLE t1(a INT, b INT);", true); + q.execute(state); + q = new SQLQueryAdapter("CREATE TABLE t2(c INT);", true); + q.execute(state); + String queryPlan = provider.getQueryPlan("SELECT * FROM t1 RIGHT JOIN t2 ON a<>0;", state); + + assertEquals( + "HashJoin_7;TableReader_10(Build);Selection_9;TableFullScan_8;TableReader_12(Probe);TableFullScan_11;", + queryPlan); + } + +} diff --git a/test/sqlancer/reducer/TestASTBasedReducer.java b/test/sqlancer/reducer/TestASTBasedReducer.java new file mode 100644 index 000000000..549bbe84d --- /dev/null +++ b/test/sqlancer/reducer/TestASTBasedReducer.java @@ -0,0 +1,267 @@ +package sqlancer.reducer; + +import net.sf.jsqlparser.JSQLParserException; +import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import org.junit.jupiter.api.Test; +import sqlancer.common.query.Query; + +import java.util.List; +import java.util.function.Function; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class TestASTBasedReducer { + @Test + void testRemovingClauses() throws Exception { + TestEnvironment env = TestEnvironment.getASTBasedReducerEnv(); + + String[] queriesStr = { + "SELECT DISTINCT * FROM v0 WHERE ((v0.rowid || ( (v0.c + v0.d) < 200 && v0.c >= 100) || 114514)OR(((v0.c0)||(1529686005)))) UNION SELECT DISTINCT * FROM v0 WHERE (NOT ((v0.rowid)OR(((v0.c0)||(1529686005))))) UNION SELECT DISTINCT * FROM v0 WHERE ((((v0.rowid)OR(((v0.c0)||(1529686005))))) IS NULL)" }; + env.setInitialStatementsFromStrings(List.of(queriesStr)); + env.setBugInducingCondition(statements -> { + String queriesString = TestEnvironment.getQueriesString(statements); + try { + CCJSqlParserUtil.parse(queriesString); + } catch (JSQLParserException e) { + return false; + } + return queriesString.contains("&&"); + }); + env.runReduce(); + List> reducedResult = env.getReducedStatements(); + String outcome = TestEnvironment.getQueriesString(reducedResult); + assertEquals(outcome, "SELECT * FROM v0 WHERE v0.c && v0.c;"); + } + + @Test + void testReducingMultipleTokensToOne() throws Exception { + TestEnvironment env = TestEnvironment.getASTBasedReducerEnv(); + + String[] queriesStr = { + "SELECT DISTINCT row_id, c FROM v0 WHERE ((v0.rowid || (v0.c < 200 && v0.c >= 100) || 114514)OR(((v0.c0)||(1529686005)))) UNION SELECT DISTINCT * FROM v0 WHERE (NOT ((v0.rowid)OR(((v0.c0)||(1529686005))))) UNION SELECT DISTINCT * FROM v0 WHERE ((((v0.rowid)OR(((v0.c0)||(1529686005))))) IS NULL)" }; + env.setInitialStatementsFromStrings(List.of(queriesStr)); + env.setBugInducingCondition(statements -> { + String queriesString = TestEnvironment.getQueriesString(statements); + try { + CCJSqlParserUtil.parse(queriesString); + } catch (JSQLParserException e) { + return false; + } + return queriesString.contains("||"); + }); + env.runReduce(); + List> reducedResult = env.getReducedStatements(); + assertEquals(TestEnvironment.getQueriesString(reducedResult), "SELECT row_id FROM v0 WHERE v0.rowid || 0;"); + } + + @Test + void testMultipleStatements() throws Exception { + TestEnvironment env = TestEnvironment.getASTBasedReducerEnv(); + + String[] queriesStrs = { + "SELECT DISTINCT row_id, c FROM v0 WHERE ((v0.rowid || (v0.c < 200 && v0.c >= 100) || 114514)OR(((v0.c0)||(1529686005)))) UNION SELECT DISTINCT * FROM v0 WHERE (NOT ((v0.rowid)OR(((v0.c0)||(1529686005))))) UNION SELECT DISTINCT * FROM v0 WHERE ((((v0.rowid)OR(((v0.c0)||(1529686005))))) IS NULL)", + "SELECT DISTINCT row_id, c FROM v0 WHERE ((v0.rowid || (v0.c < 200 && v0.c >= 100) || 114514)OR(((v0.c0)||(1529686005)))) UNION SELECT DISTINCT * FROM v0 WHERE (NOT ((v0.rowid)OR(((v0.c0)||(1529686005))))) UNION SELECT DISTINCT * FROM v0 WHERE ((((v0.rowid)OR(((v0.c0)||(1529686005))))) IS NULL)" }; + env.setInitialStatementsFromStrings(List.of(queriesStrs)); + env.setBugInducingCondition(statements -> { + String queriesString = TestEnvironment.getQueriesString(statements); + try { + for (Query s : statements) { + CCJSqlParserUtil.parse(s.getQueryString()); + } + } catch (JSQLParserException e) { + return false; + } + + return queriesString.toUpperCase().contains("UNION"); + }); + env.runReduce(); + List> reducedResult = env.getReducedStatements(); + assertEquals(TestEnvironment.getQueriesString(reducedResult), + "SELECT row_id FROM v0;\nSELECT row_id FROM v0 UNION SELECT * FROM v0;"); + } + + @Test + void removeJoins() throws Exception { + TestEnvironment env = TestEnvironment.getASTBasedReducerEnv(); + + String[] queriesStrs = { "SELECT * FROM t0, t1, t2, t3, t4 Where t2.val = t1.val" }; + env.setInitialStatementsFromStrings(List.of(queriesStrs)); + env.setBugInducingCondition(statements -> { + try { + for (Query s : statements) { + CCJSqlParserUtil.parse(s.getQueryString()); + } + } catch (JSQLParserException e) { + return false; + } + String queriesString = TestEnvironment.getQueriesString(statements); + if (!queriesString.contains("WHERE")) { + return false; + } + String[] split = queriesString.split("WHERE"); + String columns = split[0]; + String condition = split[1]; + return columns.contains("t1") && condition.contains("t1"); + }); + env.runReduce(); + List> reducedResult = env.getReducedStatements(); + assertEquals(TestEnvironment.getQueriesString(reducedResult), "SELECT * FROM t0, t1 WHERE t1.val;"); + } + + @Test + void testComplicated() throws Exception { + TestEnvironment env = TestEnvironment.getASTBasedReducerEnv(); + + String[] queriesStrs = { + "SELECT STRING_AGG(v0.c2) FROM t0, v0 WHERE (CASE true WHEN (ABS(true) BETWEEN (v0.c0 LIKE NULL ESCAPE v0.c2) AND (DATE '1970-01-23' NOT IN (v0.c2))) THEN (0.07914839711718646 NOT BETWEEN '' AND ((v0.c0)OR(v0.c2))) WHEN v0.c1 THEN ((v0.c1)-(v0.c0)) WHEN t0.c1 THEN (TIMESTAMP '1969-12-29 20:22:33' IN (PI(), v0.c2, (v0.c1 BETWEEN '' AND v0.rowid))) WHEN v0.c1 THEN TIMESTAMP '1969-12-16 17:24:43' WHEN ((((v0.c1)-(t0.c0)))||(t0.c0)) THEN true ELSE ((0.279978719843174)/(((v0.c1)>(DATE '1969-12-19')))) END ) GROUP BY ((DATE '1970-01-24') IS NULL), t0.c1, (CASE (v0.c1 LIKE ((0.9833120083624495)SIMILAR TO(t0.rowid)) ESCAPE CEIL(TIMESTAMP '1970-01-11 16:38:26')) WHEN t0.rowid THEN 0.27742217994251717 ELSE ((v0.c0) IS NOT NULL) END );" }; + env.setInitialStatementsFromStrings(List.of(queriesStrs)); + Function>, Boolean> condition = statements -> { + String queriesString = TestEnvironment.getQueriesString(statements); + try { + CCJSqlParserUtil.parse(queriesString); + } catch (JSQLParserException e) { + return false; + } + return queriesString.toUpperCase().contains("CASE"); + }; + env.setBugInducingCondition(condition); + env.runReduce(); + List> reducedResult = env.getReducedStatements(); + assertTrue(condition.apply(reducedResult)); + } + + @Test + void testSimplifyingConstantStringValue() throws Exception { + TestEnvironment env = TestEnvironment.getASTBasedReducerEnv(); + + String[] queriesStrs = { "SELECT * FROM t0 WHERE v LIKE '[vQ3㭫oQ';" }; + env.setInitialStatementsFromStrings(List.of(queriesStrs)); + env.setBugInducingCondition(statements -> { + try { + for (Query s : statements) { + CCJSqlParserUtil.parse(s.getQueryString()); + } + } catch (JSQLParserException e) { + return false; + } + + String queriesString = TestEnvironment.getQueriesString(statements); + return queriesString.contains("LIKE"); + }); + env.runReduce(); + List> reducedResult = env.getReducedStatements(); + assertEquals("SELECT * FROM t0 WHERE v LIKE '_';", TestEnvironment.getQueriesString(reducedResult)); + } + + @Test + void testSimplifyingConstantLongValue() throws Exception { + TestEnvironment env = TestEnvironment.getASTBasedReducerEnv(); + + String[] queriesStrs = { "SELECT * FROM t0 where t0.v = 314598267;" }; + env.setInitialStatementsFromStrings(List.of(queriesStrs)); + env.setBugInducingCondition(statements -> { + try { + for (Query s : statements) { + CCJSqlParserUtil.parse(s.getQueryString()); + } + } catch (JSQLParserException e) { + return false; + } + + String queriesString = TestEnvironment.getQueriesString(statements); + return queriesString.contains("t0.v = "); + }); + env.runReduce(); + List> reducedResult = env.getReducedStatements(); + assertEquals("SELECT * FROM t0 WHERE t0.v = 0;", TestEnvironment.getQueriesString(reducedResult)); + } + + @Test + void testSubSelects() throws Exception { + TestEnvironment env = TestEnvironment.getASTBasedReducerEnv(); + + String[] queriesStrs = { + "SELECT AVG(c0) FROM (SELECT SUM(c1) AS c0 FROM t1 GROUP BY c2 LIMIT 32 OFFSET 128) AS t1;" }; + env.setInitialStatementsFromStrings(List.of(queriesStrs)); + env.setBugInducingCondition(statements -> { + String queriesString = TestEnvironment.getQueriesString(statements); + try { + for (Query s : statements) { + CCJSqlParserUtil.parse(s.getQueryString()); + } + } catch (JSQLParserException e) { + return false; + } + return queriesString.contains("AVG"); + }); + env.runReduce(); + List> reducedResult = env.getReducedStatements(); + assertEquals("SELECT AVG(c0) FROM (SELECT SUM(c1) AS c0 FROM t1) AS t1;", + TestEnvironment.getQueriesString(reducedResult)); + } + + @Test + void testInsert() throws Exception { + TestEnvironment env = TestEnvironment.getASTBasedReducerEnv(); + + String[] queriesStrs = { + "INSERT INTO t1(c2, c0) VALUES (1508438260, 2929), (1508438260, TIMESTAMP '1969-12-26 01:57:21'), (0.5347171705591047, 398662142);" }; + env.setInitialStatementsFromStrings(List.of(queriesStrs)); + env.setBugInducingCondition(statements -> { + String queriesString = TestEnvironment.getQueriesString(statements); + try { + CCJSqlParserUtil.parse(queriesString); + } catch (JSQLParserException e) { + return false; + } + return queriesString.contains("(0.5347171705591047, 398662142)"); + }); + env.runReduce(); + List> reducedResult = env.getReducedStatements(); + assertEquals("INSERT INTO t1 (c2, c0) VALUES (0.5347171705591047, 398662142);", + TestEnvironment.getQueriesString(reducedResult)); + } + + @Test + void testWithSelect() throws Exception { + TestEnvironment env = TestEnvironment.getASTBasedReducerEnv(); + + String[] queriesStrs = { + "WITH cte1 AS (SELECT a, b FROM table1 where a < b), cte2 AS (SELECT c, d FROM table2 where c = d) SELECT b, d FROM cte1 JOIN cte2 WHERE cte1.a = cte2.c;" }; + env.setInitialStatementsFromStrings(List.of(queriesStrs)); + env.setBugInducingCondition(statements -> { + String queriesString = TestEnvironment.getQueriesString(statements); + try { + CCJSqlParserUtil.parse(queriesString); + } catch (JSQLParserException e) { + return false; + } + return queriesString.contains("table1"); + }); + env.runReduce(); + List> reducedResult = env.getReducedStatements(); + assertEquals("WITH cte1 AS (SELECT a FROM table1) SELECT b FROM cte1;", + TestEnvironment.getQueriesString(reducedResult)); + } + + @Test + void testRoundDouble() throws Exception { + TestEnvironment env = TestEnvironment.getASTBasedReducerEnv(); + String[] queriesStrs = { "SELECT * FROM t0 WHERE (2.1427572639 IS NULL);" }; + env.setInitialStatementsFromStrings(List.of(queriesStrs)); + env.setBugInducingCondition(statements -> { + String queriesString = TestEnvironment.getQueriesString(statements); + try { + CCJSqlParserUtil.parse(queriesString); + } catch (JSQLParserException e) { + return false; + } + return queriesString.contains("WHERE"); + }); + env.runReduce(); + List> reducedResult = env.getReducedStatements(); + assertEquals("SELECT * FROM t0 WHERE 2.143 IS NULL;", TestEnvironment.getQueriesString(reducedResult)); + } + +} diff --git a/test/sqlancer/reducer/TestEnvironment.java b/test/sqlancer/reducer/TestEnvironment.java new file mode 100644 index 000000000..7c1155490 --- /dev/null +++ b/test/sqlancer/reducer/TestEnvironment.java @@ -0,0 +1,139 @@ +package sqlancer.reducer; + +import sqlancer.*; +import sqlancer.common.query.Query; +import sqlancer.reducer.VirtualDB.VirtualDBGlobalState; +import sqlancer.reducer.VirtualDB.VirtualDBProvider; +import sqlancer.reducer.VirtualDB.VirtualDBQuery; + +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.List; +import java.util.ServiceLoader; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * TODO: Make Connection a generic type OR Fake a conn QUERY AND CONNECTION BOTH ARE FAKE. FAKE QUERY sub class + */ +public class TestEnvironment { + private final String databaseName = "virtual_db"; + private final MainOptions options = new MainOptions(); + private VirtualDBProvider provider = null; + private VirtualDBGlobalState state, newGlobalState; + + private Reducer reducer = null; + + enum ReducerType { + USING_STATEMENT_REDUCER, USING_AST_BASED_REDUCER + }; + + private TestEnvironment(ReducerType type) throws Exception { + setUpTestingEnvironment(); + if (type == ReducerType.USING_STATEMENT_REDUCER) { + reducer = new StatementReducer<>(provider); + } else if (type == ReducerType.USING_AST_BASED_REDUCER) { + reducer = new ASTBasedReducer<>(provider); + } + } + + public static TestEnvironment getStatementReducerEnv() throws Exception { + return new TestEnvironment(ReducerType.USING_STATEMENT_REDUCER); + } + + public static TestEnvironment getASTBasedReducerEnv() throws Exception { + return new TestEnvironment(ReducerType.USING_AST_BASED_REDUCER); + } + + /** + * @param queries: + * List of Query + * + * @return String of queries that appended together with '\n' separated (no '\n' at the last line) + */ + public static String getQueriesString(List> queries) { + return queries.stream().map(Query::getQueryString).collect(Collectors.joining("\n")); + } + + private VirtualDBGlobalState createGlobalState() { + try { + return provider.getGlobalStateClass().getDeclaredConstructor().newInstance(); + } catch (Exception e) { + throw new AssertionError(e); + } + } + + @SuppressWarnings("rawtypes") + private void initVirtualDBProvider() { + try { + ServiceLoader loader = ServiceLoader.load(DatabaseProvider.class); + for (DatabaseProvider provider : loader) { + if (provider.getDBMSName().equals(databaseName)) { + this.provider = (VirtualDBProvider) provider; + break; + } + } + if (provider == null) { + throw new AssertionError("testing provider not registered"); + } + } catch (Exception e) { + throw new AssertionError(e); + } + + } + + private void setUpTestingEnvironment() throws Exception { + initVirtualDBProvider(); + state = createGlobalState(); + StateToReproduce stateToReproduce = provider.getStateToReproduce(databaseName); + + state.setState(stateToReproduce); + state.setDatabaseName(databaseName); + // A really hacky way to enable reducer... + Field field = options.getClass().getDeclaredField("useReducer"); + field.setAccessible(true); + field.set(options, true); + state.setMainOptions(options); + + // Main.StateLogger logger = new Main.StateLogger(databaseName, provider, options); + // state.setStateLogger(logger); + + try (SQLConnection con = provider.createDatabase(state)) { + state.setConnection(con); + newGlobalState = createGlobalState(); + Main.StateLogger newLogger = new Main.StateLogger(databaseName, provider, options); + newGlobalState.setStateLogger(newLogger); + state.setStateLogger(newLogger); + newGlobalState.setState(stateToReproduce); + newGlobalState.setDatabaseName(databaseName); + newGlobalState.setMainOptions(options); + } + } + + public void setInitialStatementsFromStrings(List statements) { + List> queries = new ArrayList<>(); + for (String s : statements) { + queries.add(new VirtualDBQuery(s)); + } + state.getState().setStatements(queries); + } + + public void setBugInducingCondition(Function>, Boolean> bugInducingCondition) { + state.setBugInducingCondition(bugInducingCondition); + newGlobalState.setBugInducingCondition(bugInducingCondition); + } + + public void runReduce() throws Exception { + + Reproducer reproducer = provider.generateAndTestDatabase(newGlobalState); + reducer.reduce(state, reproducer, newGlobalState); + } + + public List> getReducedStatements() { + return newGlobalState.getState().getStatements(); + } + + public List> getInitialStatements() { + return state.getState().getStatements(); + } +} diff --git a/test/sqlancer/reducer/TestStatementReducer.java b/test/sqlancer/reducer/TestStatementReducer.java new file mode 100644 index 000000000..d9e364c84 --- /dev/null +++ b/test/sqlancer/reducer/TestStatementReducer.java @@ -0,0 +1,82 @@ +package sqlancer.reducer; + +import org.junit.jupiter.api.Test; +import sqlancer.Main; +import sqlancer.common.query.Query; + +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Pattern; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class TestStatementReducer { + + @Test + void testSimple() throws Exception { + TestEnvironment env = TestEnvironment.getStatementReducerEnv(); + + String[] queriesStr = { "CREATE TABLE FAKE_TABLE;", "SELECT * FROM FAKE_TABLE;", "EXIT", }; + env.setInitialStatementsFromStrings(List.of(queriesStr)); + env.setBugInducingCondition(statements -> { + String queriesString = TestEnvironment.getQueriesString(statements); + return queriesString.contains("SELECT"); + }); + env.runReduce(); + List> reducedResult = env.getReducedStatements(); + assertEquals(1, reducedResult.size()); + assertEquals("SELECT * FROM FAKE_TABLE;", reducedResult.get(0).toString()); + + } + + @Test + void testDeltaDebugging() throws Exception { + TestEnvironment env = TestEnvironment.getStatementReducerEnv(); + List fakeStatements = new ArrayList<>(); + for (int i = 0; i < 10000; i++) { + String statement = "Statement_" + i + ";"; + fakeStatements.add(statement); + } + + env.setInitialStatementsFromStrings(fakeStatements); + env.setBugInducingCondition(statements -> { + String queries = TestEnvironment.getQueriesString(statements); + return queries.contains("Statement_29;"); + }); + + env.runReduce(); + List> reducedQueries = env.getReducedStatements(); + String queriesString = TestEnvironment.getQueriesString(reducedQueries); + assertEquals(queriesString, "Statement_29;"); + } + + @Test + void testDeltaDebuggingWithStatementsCombination() throws Exception { + TestEnvironment env = TestEnvironment.getStatementReducerEnv(); + List fakeStatements = new ArrayList<>(); + + String pattern = "(.*\\n)*(Statement_2;)\\n(.*\\n)*(Statement_318);\\n(.*\\n)*(Statement_990;)(.*\\n)*.*"; + for (int i = 0; i < 1000; i++) { + String statement = "Statement_" + i + ";"; + fakeStatements.add(statement); + } + + env.setInitialStatementsFromStrings(fakeStatements); + env.setBugInducingCondition(queryList -> { + String queries = TestEnvironment.getQueriesString(queryList); + return Pattern.matches(pattern, queries); + }); + + env.runReduce(); + List> reducedQueries = env.getReducedStatements(); + String queriesString = TestEnvironment.getQueriesString(reducedQueries); + assertEquals(queriesString, "Statement_2;\nStatement_318;\nStatement_990;"); + } + + @Test + void testSQLite3WithStatementReducer() { + Main.executeMain(new String[] { "--random-seed", "0", "--use-reducer", "--timeout-seconds", "60", + "--num-threads", "4", "sqlite3", "--oracle", "NoREC" }); + } + +} diff --git a/test/sqlancer/reducer/VirtualDB/VirtualDBConnection.java b/test/sqlancer/reducer/VirtualDB/VirtualDBConnection.java new file mode 100644 index 000000000..6a562a90d --- /dev/null +++ b/test/sqlancer/reducer/VirtualDB/VirtualDBConnection.java @@ -0,0 +1,18 @@ +package sqlancer.reducer.VirtualDB; + +import sqlancer.SQLConnection; + +import java.sql.Connection; +import java.sql.SQLException; + +public class VirtualDBConnection extends SQLConnection { + + public VirtualDBConnection(Connection connection) { + super(connection); + } + + @Override + public void close() throws SQLException { + + } +} diff --git a/test/sqlancer/reducer/VirtualDB/VirtualDBErrors.java b/test/sqlancer/reducer/VirtualDB/VirtualDBErrors.java new file mode 100644 index 000000000..64c783c11 --- /dev/null +++ b/test/sqlancer/reducer/VirtualDB/VirtualDBErrors.java @@ -0,0 +1,12 @@ +package sqlancer.reducer.VirtualDB; + +import sqlancer.common.query.ExpectedErrors; + +public final class VirtualDBErrors { + public VirtualDBErrors() { + } + + public static void addErrors(ExpectedErrors errors) { + errors.add("Default error"); + } +} diff --git a/test/sqlancer/reducer/VirtualDB/VirtualDBGlobalState.java b/test/sqlancer/reducer/VirtualDB/VirtualDBGlobalState.java new file mode 100644 index 000000000..a0548495c --- /dev/null +++ b/test/sqlancer/reducer/VirtualDB/VirtualDBGlobalState.java @@ -0,0 +1,56 @@ +package sqlancer.reducer.VirtualDB; + +import sqlancer.SQLConnection; +import sqlancer.SQLGlobalState; +import sqlancer.common.query.Query; + +import java.util.List; +import java.util.function.Function; + +@SuppressWarnings("all") +public class VirtualDBGlobalState extends SQLGlobalState { + + private SQLConnection virtualConn = new SQLConnection(null); + private StringBuilder queriesStringBuilder = new StringBuilder(); + private Function>, Boolean> bugInducingCondition = null; + + public Function>, Boolean> getBugInducingCondition() { + return bugInducingCondition; + } + + public void setBugInducingCondition(Function>, Boolean> condition) { + bugInducingCondition = (condition); + } + + @Override + protected VirtualDBSchema readSchema() throws Exception { + return null; + } + + @Override + public SQLConnection getConnection() { + // It's a fake engine, so the connection would not be available :) + return virtualConn; + } + + @Override + public void setConnection(SQLConnection con) { + // A fake connection could also not be closed. + // So nothing would be done here. + // And reset the query String (Seems needless) + // queriesStringBuilder = new StringBuilder(); + } + + // public String getCurrentQueriesString() { + // return queriesStringBuilder.toString(); + // } + + @Override + public boolean executeStatement(Query q, String... fills) throws Exception { + if (queriesStringBuilder.length() != 0) { + queriesStringBuilder.append("\n"); + } + queriesStringBuilder.append(q.getQueryString()); + return true; + } +} diff --git a/test/sqlancer/reducer/VirtualDB/VirtualDBOptions.java b/test/sqlancer/reducer/VirtualDB/VirtualDBOptions.java new file mode 100644 index 000000000..72cbe918c --- /dev/null +++ b/test/sqlancer/reducer/VirtualDB/VirtualDBOptions.java @@ -0,0 +1,29 @@ +package sqlancer.reducer.VirtualDB; + +import com.beust.jcommander.Parameters; +import sqlancer.DBMSSpecificOptions; +import sqlancer.OracleFactory; +import sqlancer.common.oracle.TestOracle; +import sqlancer.reducer.VirtualDB.VirtualDBOptions.VirtualDBFactory; + +import java.util.ArrayList; +import java.util.List; + +@Parameters(separators = "=", commandDescription = "VirtualDB (default port: " + "-1" + ", default host: " + "127.0.0.1" + + ")") +public class VirtualDBOptions implements DBMSSpecificOptions { + + List factories = new ArrayList<>(); + + @Override + public List getTestOracleFactory() { + return factories; + } + + public static class VirtualDBFactory implements OracleFactory { + @Override + public TestOracle create(VirtualDBGlobalState globalState) throws Exception { + return null; + } + } +} diff --git a/test/sqlancer/reducer/VirtualDB/VirtualDBProvider.java b/test/sqlancer/reducer/VirtualDB/VirtualDBProvider.java new file mode 100644 index 000000000..697e5c60e --- /dev/null +++ b/test/sqlancer/reducer/VirtualDB/VirtualDBProvider.java @@ -0,0 +1,54 @@ +package sqlancer.reducer.VirtualDB; + +import com.google.auto.service.AutoService; +import sqlancer.DatabaseProvider; +import sqlancer.Reproducer; +import sqlancer.SQLConnection; +import sqlancer.SQLProviderAdapter; + +@AutoService(DatabaseProvider.class) +public class VirtualDBProvider extends SQLProviderAdapter { + + private Reproducer reproducerForTesting; + + public VirtualDBProvider() { + super(VirtualDBGlobalState.class, VirtualDBOptions.class); + } + + @Override + public SQLConnection createDatabase(VirtualDBGlobalState globalState) throws Exception { + return new VirtualDBConnection(null); + } + + @Override + public String getDBMSName() { + return "virtual_db"; + } + + @Override + public void generateDatabase(VirtualDBGlobalState globalState) throws Exception { + + } + + @Override + public Reproducer generateAndTestDatabase(VirtualDBGlobalState globalState) throws Exception { + return state -> { + if (globalState.getBugInducingCondition() == null) + return false; + return globalState.getBugInducingCondition().apply(globalState.getState().getStatements()); + }; + } + + @Override + public Class getGlobalStateClass() { + return super.getGlobalStateClass(); + } + + public Reproducer getReproducerForTesting() { + return reproducerForTesting; + } + + public void setReproducerForTesting(Reproducer reproducer) { + this.reproducerForTesting = reproducer; + } +} diff --git a/test/sqlancer/reducer/VirtualDB/VirtualDBQuery.java b/test/sqlancer/reducer/VirtualDB/VirtualDBQuery.java new file mode 100644 index 000000000..9767723e0 --- /dev/null +++ b/test/sqlancer/reducer/VirtualDB/VirtualDBQuery.java @@ -0,0 +1,31 @@ +package sqlancer.reducer.VirtualDB; + +import sqlancer.GlobalState; +import sqlancer.SQLConnection; +import sqlancer.common.query.SQLQueryAdapter; + +import java.sql.SQLException; + +public class VirtualDBQuery extends SQLQueryAdapter { + private static final long serialVersionUID = 1L; + + public VirtualDBQuery(String query) { + // Since the base class must check the format + // We judge if the statement could affect schema. A bit hacky tho. + super(query, (query.contains("CREATE TABLE") && !query.startsWith("EXPLAIN"))); + } + + public VirtualDBQuery(String query, boolean couldAffectSchema) { + super(query, couldAffectSchema); + } + + @Override + public > boolean execute(G globalState, String... fills) + throws SQLException { + try { + return globalState.executeStatement(this, fills); + } catch (Exception e) { + throw new SQLException(e); + } + } +} diff --git a/test/sqlancer/reducer/VirtualDB/VirtualDBSchema.java b/test/sqlancer/reducer/VirtualDB/VirtualDBSchema.java new file mode 100644 index 000000000..9ea5b422e --- /dev/null +++ b/test/sqlancer/reducer/VirtualDB/VirtualDBSchema.java @@ -0,0 +1,46 @@ +package sqlancer.reducer.VirtualDB; + +import sqlancer.common.schema.AbstractSchema; +import sqlancer.common.schema.AbstractTable; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.common.schema.TableIndex; +import sqlancer.reducer.VirtualDB.VirtualDBSchema.VirtualDBTable; + +import java.util.List; + +public class VirtualDBSchema extends AbstractSchema { + + public VirtualDBSchema(List databaseTables) { + super(databaseTables); + } + + public static class VirtualDBTable extends AbstractTable { + protected VirtualDBTable(String name, List columns, List indexes, + boolean isView) { + super(name, columns, indexes, isView); + } + + @Override + public long getNrRows(VirtualDBGlobalState globalState) { + return 0; + } + } + + public static class VirtualDBIndex extends TableIndex { + + protected VirtualDBIndex(String indexName) { + super(indexName); + } + } + + public static class VirtualDBDataType { + + } + + public static class VirtualDBColumn extends AbstractTableColumn { + + public VirtualDBColumn(String name, VirtualDBTable table, VirtualDBDataType type) { + super(name, table, type); + } + } +}