This is page 45 of 45. Use http://codebase.md/googleapis/genai-toolbox?lines=true&page={x} to view the full context. # Directory Structure ``` ├── .ci │ ├── continuous.release.cloudbuild.yaml │ ├── generate_release_table.sh │ ├── integration.cloudbuild.yaml │ ├── quickstart_test │ │ ├── go.integration.cloudbuild.yaml │ │ ├── js.integration.cloudbuild.yaml │ │ ├── py.integration.cloudbuild.yaml │ │ ├── run_go_tests.sh │ │ ├── run_js_tests.sh │ │ ├── run_py_tests.sh │ │ └── setup_hotels_sample.sql │ ├── test_with_coverage.sh │ └── versioned.release.cloudbuild.yaml ├── .github │ ├── auto-label.yaml │ ├── blunderbuss.yml │ ├── CODEOWNERS │ ├── header-checker-lint.yml │ ├── ISSUE_TEMPLATE │ │ ├── bug_report.yml │ │ ├── config.yml │ │ ├── feature_request.yml │ │ └── question.yml │ ├── label-sync.yml │ ├── labels.yaml │ ├── PULL_REQUEST_TEMPLATE.md │ ├── release-please.yml │ ├── renovate.json5 │ ├── sync-repo-settings.yaml │ └── workflows │ ├── cloud_build_failure_reporter.yml │ ├── deploy_dev_docs.yaml │ ├── deploy_previous_version_docs.yaml │ ├── deploy_versioned_docs.yaml │ ├── docs_deploy.yaml │ ├── docs_preview_clean.yaml │ ├── docs_preview_deploy.yaml │ ├── lint.yaml │ ├── schedule_reporter.yml │ ├── sync-labels.yaml │ └── tests.yaml ├── .gitignore ├── .gitmodules ├── .golangci.yaml ├── .hugo │ ├── archetypes │ │ └── default.md │ ├── assets │ │ ├── icons │ │ │ └── logo.svg │ │ └── scss │ │ ├── _styles_project.scss │ │ └── _variables_project.scss │ ├── go.mod │ ├── go.sum │ ├── hugo.toml │ ├── layouts │ │ ├── _default │ │ │ └── home.releases.releases │ │ ├── index.llms-full.txt │ │ ├── index.llms.txt │ │ ├── partials │ │ │ ├── hooks │ │ │ │ └── head-end.html │ │ │ ├── navbar-version-selector.html │ │ │ ├── page-meta-links.html │ │ │ └── td │ │ │ └── render-heading.html │ │ ├── robot.txt │ │ └── shortcodes │ │ ├── include.html │ │ ├── ipynb.html │ │ └── regionInclude.html │ ├── package-lock.json │ ├── package.json │ └── static │ ├── favicons │ │ ├── android-chrome-192x192.png │ │ ├── android-chrome-512x512.png │ │ ├── apple-touch-icon.png │ │ ├── favicon-16x16.png │ │ ├── favicon-32x32.png │ │ └── favicon.ico │ └── js │ └── w3.js ├── CHANGELOG.md ├── cmd │ ├── options_test.go │ ├── options.go │ ├── root_test.go │ ├── root.go │ └── version.txt ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── DEVELOPER.md ├── Dockerfile ├── docs │ └── en │ ├── _index.md │ ├── about │ │ ├── _index.md │ │ └── faq.md │ ├── concepts │ │ ├── _index.md │ │ └── telemetry │ │ ├── index.md │ │ ├── telemetry_flow.png │ │ └── telemetry_traces.png │ ├── getting-started │ │ ├── _index.md │ │ ├── colab_quickstart.ipynb │ │ ├── configure.md │ │ ├── introduction │ │ │ ├── _index.md │ │ │ └── architecture.png │ │ ├── local_quickstart_go.md │ │ ├── local_quickstart_js.md │ │ ├── local_quickstart.md │ │ ├── mcp_quickstart │ │ │ ├── _index.md │ │ │ ├── inspector_tools.png │ │ │ └── inspector.png │ │ └── quickstart │ │ ├── go │ │ │ ├── genAI │ │ │ │ ├── go.mod │ │ │ │ ├── go.sum │ │ │ │ └── quickstart.go │ │ │ ├── genkit │ │ │ │ ├── go.mod │ │ │ │ ├── go.sum │ │ │ │ └── quickstart.go │ │ │ ├── langchain │ │ │ │ ├── go.mod │ │ │ │ ├── go.sum │ │ │ │ └── quickstart.go │ │ │ ├── openAI │ │ │ │ ├── go.mod │ │ │ │ ├── go.sum │ │ │ │ └── quickstart.go │ │ │ └── quickstart_test.go │ │ ├── golden.txt │ │ ├── js │ │ │ ├── genAI │ │ │ │ ├── package-lock.json │ │ │ │ ├── package.json │ │ │ │ └── quickstart.js │ │ │ ├── genkit │ │ │ │ ├── package-lock.json │ │ │ │ ├── package.json │ │ │ │ └── quickstart.js │ │ │ ├── langchain │ │ │ │ ├── 
package-lock.json │ │ │ │ ├── package.json │ │ │ │ └── quickstart.js │ │ │ ├── llamaindex │ │ │ │ ├── package-lock.json │ │ │ │ ├── package.json │ │ │ │ └── quickstart.js │ │ │ └── quickstart.test.js │ │ ├── python │ │ │ ├── __init__.py │ │ │ ├── adk │ │ │ │ ├── quickstart.py │ │ │ │ └── requirements.txt │ │ │ ├── core │ │ │ │ ├── quickstart.py │ │ │ │ └── requirements.txt │ │ │ ├── langchain │ │ │ │ ├── quickstart.py │ │ │ │ └── requirements.txt │ │ │ ├── llamaindex │ │ │ │ ├── quickstart.py │ │ │ │ └── requirements.txt │ │ │ └── quickstart_test.py │ │ └── shared │ │ ├── cloud_setup.md │ │ ├── configure_toolbox.md │ │ └── database_setup.md │ ├── how-to │ │ ├── _index.md │ │ ├── connect_via_geminicli.md │ │ ├── connect_via_mcp.md │ │ ├── connect-ide │ │ │ ├── _index.md │ │ │ ├── alloydb_pg_admin_mcp.md │ │ │ ├── alloydb_pg_mcp.md │ │ │ ├── bigquery_mcp.md │ │ │ ├── cloud_sql_mssql_admin_mcp.md │ │ │ ├── cloud_sql_mssql_mcp.md │ │ │ ├── cloud_sql_mysql_admin_mcp.md │ │ │ ├── cloud_sql_mysql_mcp.md │ │ │ ├── cloud_sql_pg_admin_mcp.md │ │ │ ├── cloud_sql_pg_mcp.md │ │ │ ├── firestore_mcp.md │ │ │ ├── looker_mcp.md │ │ │ ├── mssql_mcp.md │ │ │ ├── mysql_mcp.md │ │ │ ├── neo4j_mcp.md │ │ │ ├── postgres_mcp.md │ │ │ ├── spanner_mcp.md │ │ │ └── sqlite_mcp.md │ │ ├── deploy_docker.md │ │ ├── deploy_gke.md │ │ ├── deploy_toolbox.md │ │ ├── export_telemetry.md │ │ └── toolbox-ui │ │ ├── edit-headers.gif │ │ ├── edit-headers.png │ │ ├── index.md │ │ ├── optional-param-checked.png │ │ ├── optional-param-unchecked.png │ │ ├── run-tool.gif │ │ ├── tools.png │ │ └── toolsets.png │ ├── reference │ │ ├── _index.md │ │ ├── cli.md │ │ └── prebuilt-tools.md │ ├── resources │ │ ├── _index.md │ │ ├── authServices │ │ │ ├── _index.md │ │ │ └── google.md │ │ ├── sources │ │ │ ├── _index.md │ │ │ ├── alloydb-admin.md │ │ │ ├── alloydb-pg.md │ │ │ ├── bigquery.md │ │ │ ├── bigtable.md │ │ │ ├── cassandra.md │ │ │ ├── clickhouse.md │ │ │ ├── cloud-monitoring.md │ │ │ ├── cloud-sql-admin.md │ │ │ ├── cloud-sql-mssql.md │ │ │ ├── cloud-sql-mysql.md │ │ │ ├── cloud-sql-pg.md │ │ │ ├── couchbase.md │ │ │ ├── dataplex.md │ │ │ ├── dgraph.md │ │ │ ├── firebird.md │ │ │ ├── firestore.md │ │ │ ├── http.md │ │ │ ├── looker.md │ │ │ ├── mongodb.md │ │ │ ├── mssql.md │ │ │ ├── mysql.md │ │ │ ├── neo4j.md │ │ │ ├── oceanbase.md │ │ │ ├── oracle.md │ │ │ ├── postgres.md │ │ │ ├── redis.md │ │ │ ├── spanner.md │ │ │ ├── sqlite.md │ │ │ ├── tidb.md │ │ │ ├── trino.md │ │ │ ├── valkey.md │ │ │ └── yugabytedb.md │ │ └── tools │ │ ├── _index.md │ │ ├── alloydb │ │ │ ├── _index.md │ │ │ ├── alloydb-create-cluster.md │ │ │ ├── alloydb-create-instance.md │ │ │ ├── alloydb-create-user.md │ │ │ ├── alloydb-get-cluster.md │ │ │ ├── alloydb-get-instance.md │ │ │ ├── alloydb-get-user.md │ │ │ ├── alloydb-list-clusters.md │ │ │ ├── alloydb-list-instances.md │ │ │ ├── alloydb-list-users.md │ │ │ └── alloydb-wait-for-operation.md │ │ ├── alloydbainl │ │ │ ├── _index.md │ │ │ └── alloydb-ai-nl.md │ │ ├── bigquery │ │ │ ├── _index.md │ │ │ ├── bigquery-analyze-contribution.md │ │ │ ├── bigquery-conversational-analytics.md │ │ │ ├── bigquery-execute-sql.md │ │ │ ├── bigquery-forecast.md │ │ │ ├── bigquery-get-dataset-info.md │ │ │ ├── bigquery-get-table-info.md │ │ │ ├── bigquery-list-dataset-ids.md │ │ │ ├── bigquery-list-table-ids.md │ │ │ ├── bigquery-search-catalog.md │ │ │ └── bigquery-sql.md │ │ ├── bigtable │ │ │ ├── _index.md │ │ │ └── bigtable-sql.md │ │ ├── cassandra │ │ │ ├── _index.md │ │ │ └── cassandra-cql.md │ │ ├── clickhouse │ │ 
│ ├── _index.md │ │ │ ├── clickhouse-execute-sql.md │ │ │ ├── clickhouse-list-databases.md │ │ │ ├── clickhouse-list-tables.md │ │ │ └── clickhouse-sql.md │ │ ├── cloudmonitoring │ │ │ ├── _index.md │ │ │ └── cloud-monitoring-query-prometheus.md │ │ ├── cloudsql │ │ │ ├── _index.md │ │ │ ├── cloudsqlcreatedatabase.md │ │ │ ├── cloudsqlcreateusers.md │ │ │ ├── cloudsqlgetinstances.md │ │ │ ├── cloudsqllistdatabases.md │ │ │ ├── cloudsqllistinstances.md │ │ │ ├── cloudsqlmssqlcreateinstance.md │ │ │ ├── cloudsqlmysqlcreateinstance.md │ │ │ ├── cloudsqlpgcreateinstances.md │ │ │ └── cloudsqlwaitforoperation.md │ │ ├── couchbase │ │ │ ├── _index.md │ │ │ └── couchbase-sql.md │ │ ├── dataform │ │ │ ├── _index.md │ │ │ └── dataform-compile-local.md │ │ ├── dataplex │ │ │ ├── _index.md │ │ │ ├── dataplex-lookup-entry.md │ │ │ ├── dataplex-search-aspect-types.md │ │ │ └── dataplex-search-entries.md │ │ ├── dgraph │ │ │ ├── _index.md │ │ │ └── dgraph-dql.md │ │ ├── firebird │ │ │ ├── _index.md │ │ │ ├── firebird-execute-sql.md │ │ │ └── firebird-sql.md │ │ ├── firestore │ │ │ ├── _index.md │ │ │ ├── firestore-add-documents.md │ │ │ ├── firestore-delete-documents.md │ │ │ ├── firestore-get-documents.md │ │ │ ├── firestore-get-rules.md │ │ │ ├── firestore-list-collections.md │ │ │ ├── firestore-query-collection.md │ │ │ ├── firestore-query.md │ │ │ ├── firestore-update-document.md │ │ │ └── firestore-validate-rules.md │ │ ├── http │ │ │ ├── _index.md │ │ │ └── http.md │ │ ├── looker │ │ │ ├── _index.md │ │ │ ├── looker-add-dashboard-element.md │ │ │ ├── looker-conversational-analytics.md │ │ │ ├── looker-get-dashboards.md │ │ │ ├── looker-get-dimensions.md │ │ │ ├── looker-get-explores.md │ │ │ ├── looker-get-filters.md │ │ │ ├── looker-get-looks.md │ │ │ ├── looker-get-measures.md │ │ │ ├── looker-get-models.md │ │ │ ├── looker-get-parameters.md │ │ │ ├── looker-health-analyze.md │ │ │ ├── looker-health-pulse.md │ │ │ ├── looker-health-vacuum.md │ │ │ ├── looker-make-dashboard.md │ │ │ ├── looker-make-look.md │ │ │ ├── looker-query-sql.md │ │ │ ├── looker-query-url.md │ │ │ ├── looker-query.md │ │ │ └── looker-run-look.md │ │ ├── mongodb │ │ │ ├── _index.md │ │ │ ├── mongodb-aggregate.md │ │ │ ├── mongodb-delete-many.md │ │ │ ├── mongodb-delete-one.md │ │ │ ├── mongodb-find-one.md │ │ │ ├── mongodb-find.md │ │ │ ├── mongodb-insert-many.md │ │ │ ├── mongodb-insert-one.md │ │ │ ├── mongodb-update-many.md │ │ │ └── mongodb-update-one.md │ │ ├── mssql │ │ │ ├── _index.md │ │ │ ├── mssql-execute-sql.md │ │ │ ├── mssql-list-tables.md │ │ │ └── mssql-sql.md │ │ ├── mysql │ │ │ ├── _index.md │ │ │ ├── mysql-execute-sql.md │ │ │ ├── mysql-list-active-queries.md │ │ │ ├── mysql-list-table-fragmentation.md │ │ │ ├── mysql-list-tables-missing-unique-indexes.md │ │ │ ├── mysql-list-tables.md │ │ │ └── mysql-sql.md │ │ ├── neo4j │ │ │ ├── _index.md │ │ │ ├── neo4j-cypher.md │ │ │ ├── neo4j-execute-cypher.md │ │ │ └── neo4j-schema.md │ │ ├── oceanbase │ │ │ ├── _index.md │ │ │ ├── oceanbase-execute-sql.md │ │ │ └── oceanbase-sql.md │ │ ├── oracle │ │ │ ├── _index.md │ │ │ ├── oracle-execute-sql.md │ │ │ └── oracle-sql.md │ │ ├── postgres │ │ │ ├── _index.md │ │ │ ├── postgres-execute-sql.md │ │ │ ├── postgres-list-active-queries.md │ │ │ ├── postgres-list-available-extensions.md │ │ │ ├── postgres-list-installed-extensions.md │ │ │ ├── postgres-list-tables.md │ │ │ └── postgres-sql.md │ │ ├── redis │ │ │ ├── _index.md │ │ │ └── redis.md │ │ ├── spanner │ │ │ ├── _index.md │ │ │ ├── spanner-execute-sql.md │ │ │ ├── 
spanner-list-tables.md │ │ │ └── spanner-sql.md │ │ ├── sqlite │ │ │ ├── _index.md │ │ │ ├── sqlite-execute-sql.md │ │ │ └── sqlite-sql.md │ │ ├── tidb │ │ │ ├── _index.md │ │ │ ├── tidb-execute-sql.md │ │ │ └── tidb-sql.md │ │ ├── trino │ │ │ ├── _index.md │ │ │ ├── trino-execute-sql.md │ │ │ └── trino-sql.md │ │ ├── utility │ │ │ ├── _index.md │ │ │ └── wait.md │ │ ├── valkey │ │ │ ├── _index.md │ │ │ └── valkey.md │ │ └── yuagbytedb │ │ ├── _index.md │ │ └── yugabytedb-sql.md │ ├── samples │ │ ├── _index.md │ │ ├── alloydb │ │ │ ├── _index.md │ │ │ ├── ai-nl │ │ │ │ ├── alloydb_ai_nl.ipynb │ │ │ │ └── index.md │ │ │ └── mcp_quickstart.md │ │ ├── bigquery │ │ │ ├── _index.md │ │ │ ├── colab_quickstart_bigquery.ipynb │ │ │ ├── local_quickstart.md │ │ │ └── mcp_quickstart │ │ │ ├── _index.md │ │ │ ├── inspector_tools.png │ │ │ └── inspector.png │ │ └── looker │ │ ├── _index.md │ │ ├── looker_gemini_oauth │ │ │ ├── _index.md │ │ │ ├── authenticated.png │ │ │ ├── authorize.png │ │ │ └── registration.png │ │ ├── looker_gemini.md │ │ └── looker_mcp_inspector │ │ ├── _index.md │ │ ├── inspector_tools.png │ │ └── inspector.png │ └── sdks │ ├── _index.md │ ├── go-sdk.md │ ├── js-sdk.md │ └── python-sdk.md ├── go.mod ├── go.sum ├── internal │ ├── auth │ │ ├── auth.go │ │ └── google │ │ └── google.go │ ├── log │ │ ├── handler.go │ │ ├── log_test.go │ │ ├── log.go │ │ └── logger.go │ ├── prebuiltconfigs │ │ ├── prebuiltconfigs_test.go │ │ ├── prebuiltconfigs.go │ │ └── tools │ │ ├── alloydb-postgres-admin.yaml │ │ ├── alloydb-postgres-observability.yaml │ │ ├── alloydb-postgres.yaml │ │ ├── bigquery.yaml │ │ ├── clickhouse.yaml │ │ ├── cloud-sql-mssql-admin.yaml │ │ ├── cloud-sql-mssql-observability.yaml │ │ ├── cloud-sql-mssql.yaml │ │ ├── cloud-sql-mysql-admin.yaml │ │ ├── cloud-sql-mysql-observability.yaml │ │ ├── cloud-sql-mysql.yaml │ │ ├── cloud-sql-postgres-admin.yaml │ │ ├── cloud-sql-postgres-observability.yaml │ │ ├── cloud-sql-postgres.yaml │ │ ├── dataplex.yaml │ │ ├── firestore.yaml │ │ ├── looker-conversational-analytics.yaml │ │ ├── looker.yaml │ │ ├── mssql.yaml │ │ ├── mysql.yaml │ │ ├── neo4j.yaml │ │ ├── oceanbase.yaml │ │ ├── postgres.yaml │ │ ├── spanner-postgres.yaml │ │ ├── spanner.yaml │ │ └── sqlite.yaml │ ├── server │ │ ├── api_test.go │ │ ├── api.go │ │ ├── common_test.go │ │ ├── config.go │ │ ├── mcp │ │ │ ├── jsonrpc │ │ │ │ ├── jsonrpc_test.go │ │ │ │ └── jsonrpc.go │ │ │ ├── mcp.go │ │ │ ├── util │ │ │ │ └── lifecycle.go │ │ │ ├── v20241105 │ │ │ │ ├── method.go │ │ │ │ └── types.go │ │ │ ├── v20250326 │ │ │ │ ├── method.go │ │ │ │ └── types.go │ │ │ └── v20250618 │ │ │ ├── method.go │ │ │ └── types.go │ │ ├── mcp_test.go │ │ ├── mcp.go │ │ ├── server_test.go │ │ ├── server.go │ │ ├── static │ │ │ ├── assets │ │ │ │ └── mcptoolboxlogo.png │ │ │ ├── css │ │ │ │ └── style.css │ │ │ ├── index.html │ │ │ ├── js │ │ │ │ ├── auth.js │ │ │ │ ├── loadTools.js │ │ │ │ ├── mainContent.js │ │ │ │ ├── navbar.js │ │ │ │ ├── runTool.js │ │ │ │ ├── toolDisplay.js │ │ │ │ ├── tools.js │ │ │ │ └── toolsets.js │ │ │ ├── tools.html │ │ │ └── toolsets.html │ │ ├── web_test.go │ │ └── web.go │ ├── sources │ │ ├── alloydbadmin │ │ │ ├── alloydbadmin_test.go │ │ │ └── alloydbadmin.go │ │ ├── alloydbpg │ │ │ ├── alloydb_pg_test.go │ │ │ └── alloydb_pg.go │ │ ├── bigquery │ │ │ ├── bigquery_test.go │ │ │ └── bigquery.go │ │ ├── bigtable │ │ │ ├── bigtable_test.go │ │ │ └── bigtable.go │ │ ├── cassandra │ │ │ ├── cassandra_test.go │ │ │ └── cassandra.go │ │ ├── clickhouse │ │ │ ├── 
clickhouse_test.go │ │ │ └── clickhouse.go │ │ ├── cloudmonitoring │ │ │ ├── cloud_monitoring_test.go │ │ │ └── cloud_monitoring.go │ │ ├── cloudsqladmin │ │ │ ├── cloud_sql_admin_test.go │ │ │ └── cloud_sql_admin.go │ │ ├── cloudsqlmssql │ │ │ ├── cloud_sql_mssql_test.go │ │ │ └── cloud_sql_mssql.go │ │ ├── cloudsqlmysql │ │ │ ├── cloud_sql_mysql_test.go │ │ │ └── cloud_sql_mysql.go │ │ ├── cloudsqlpg │ │ │ ├── cloud_sql_pg_test.go │ │ │ └── cloud_sql_pg.go │ │ ├── couchbase │ │ │ ├── couchbase_test.go │ │ │ └── couchbase.go │ │ ├── dataplex │ │ │ ├── dataplex_test.go │ │ │ └── dataplex.go │ │ ├── dgraph │ │ │ ├── dgraph_test.go │ │ │ └── dgraph.go │ │ ├── dialect.go │ │ ├── firebird │ │ │ ├── firebird_test.go │ │ │ └── firebird.go │ │ ├── firestore │ │ │ ├── firestore_test.go │ │ │ └── firestore.go │ │ ├── http │ │ │ ├── http_test.go │ │ │ └── http.go │ │ ├── ip_type.go │ │ ├── looker │ │ │ ├── looker_test.go │ │ │ └── looker.go │ │ ├── mongodb │ │ │ ├── mongodb_test.go │ │ │ └── mongodb.go │ │ ├── mssql │ │ │ ├── mssql_test.go │ │ │ └── mssql.go │ │ ├── mysql │ │ │ ├── mysql_test.go │ │ │ └── mysql.go │ │ ├── neo4j │ │ │ ├── neo4j_test.go │ │ │ └── neo4j.go │ │ ├── oceanbase │ │ │ ├── oceanbase_test.go │ │ │ └── oceanbase.go │ │ ├── oracle │ │ │ └── oracle.go │ │ ├── postgres │ │ │ ├── postgres_test.go │ │ │ └── postgres.go │ │ ├── redis │ │ │ ├── redis_test.go │ │ │ └── redis.go │ │ ├── sources.go │ │ ├── spanner │ │ │ ├── spanner_test.go │ │ │ └── spanner.go │ │ ├── sqlite │ │ │ ├── sqlite_test.go │ │ │ └── sqlite.go │ │ ├── tidb │ │ │ ├── tidb_test.go │ │ │ └── tidb.go │ │ ├── trino │ │ │ ├── trino_test.go │ │ │ └── trino.go │ │ ├── util.go │ │ ├── valkey │ │ │ ├── valkey_test.go │ │ │ └── valkey.go │ │ └── yugabytedb │ │ ├── yugabytedb_test.go │ │ └── yugabytedb.go │ ├── telemetry │ │ ├── instrumentation.go │ │ └── telemetry.go │ ├── testutils │ │ └── testutils.go │ ├── tools │ │ ├── alloydb │ │ │ ├── alloydbcreatecluster │ │ │ │ ├── alloydbcreatecluster_test.go │ │ │ │ └── alloydbcreatecluster.go │ │ │ ├── alloydbcreateinstance │ │ │ │ ├── alloydbcreateinstance_test.go │ │ │ │ └── alloydbcreateinstance.go │ │ │ ├── alloydbcreateuser │ │ │ │ ├── alloydbcreateuser_test.go │ │ │ │ └── alloydbcreateuser.go │ │ │ ├── alloydbgetcluster │ │ │ │ ├── alloydbgetcluster_test.go │ │ │ │ └── alloydbgetcluster.go │ │ │ ├── alloydbgetinstance │ │ │ │ ├── alloydbgetinstance_test.go │ │ │ │ └── alloydbgetinstance.go │ │ │ ├── alloydbgetuser │ │ │ │ ├── alloydbgetuser_test.go │ │ │ │ └── alloydbgetuser.go │ │ │ ├── alloydblistclusters │ │ │ │ ├── alloydblistclusters_test.go │ │ │ │ └── alloydblistclusters.go │ │ │ ├── alloydblistinstances │ │ │ │ ├── alloydblistinstances_test.go │ │ │ │ └── alloydblistinstances.go │ │ │ ├── alloydblistusers │ │ │ │ ├── alloydblistusers_test.go │ │ │ │ └── alloydblistusers.go │ │ │ └── alloydbwaitforoperation │ │ │ ├── alloydbwaitforoperation_test.go │ │ │ └── alloydbwaitforoperation.go │ │ ├── alloydbainl │ │ │ ├── alloydbainl_test.go │ │ │ └── alloydbainl.go │ │ ├── bigquery │ │ │ ├── bigqueryanalyzecontribution │ │ │ │ ├── bigqueryanalyzecontribution_test.go │ │ │ │ └── bigqueryanalyzecontribution.go │ │ │ ├── bigquerycommon │ │ │ │ ├── table_name_parser_test.go │ │ │ │ ├── table_name_parser.go │ │ │ │ └── util.go │ │ │ ├── bigqueryconversationalanalytics │ │ │ │ ├── bigqueryconversationalanalytics_test.go │ │ │ │ └── bigqueryconversationalanalytics.go │ │ │ ├── bigqueryexecutesql │ │ │ │ ├── bigqueryexecutesql_test.go │ │ │ │ └── bigqueryexecutesql.go │ │ │ ├── 
bigqueryforecast │ │ │ │ ├── bigqueryforecast_test.go │ │ │ │ └── bigqueryforecast.go │ │ │ ├── bigquerygetdatasetinfo │ │ │ │ ├── bigquerygetdatasetinfo_test.go │ │ │ │ └── bigquerygetdatasetinfo.go │ │ │ ├── bigquerygettableinfo │ │ │ │ ├── bigquerygettableinfo_test.go │ │ │ │ └── bigquerygettableinfo.go │ │ │ ├── bigquerylistdatasetids │ │ │ │ ├── bigquerylistdatasetids_test.go │ │ │ │ └── bigquerylistdatasetids.go │ │ │ ├── bigquerylisttableids │ │ │ │ ├── bigquerylisttableids_test.go │ │ │ │ └── bigquerylisttableids.go │ │ │ ├── bigquerysearchcatalog │ │ │ │ ├── bigquerysearchcatalog_test.go │ │ │ │ └── bigquerysearchcatalog.go │ │ │ └── bigquerysql │ │ │ ├── bigquerysql_test.go │ │ │ └── bigquerysql.go │ │ ├── bigtable │ │ │ ├── bigtable_test.go │ │ │ └── bigtable.go │ │ ├── cassandra │ │ │ └── cassandracql │ │ │ ├── cassandracql_test.go │ │ │ └── cassandracql.go │ │ ├── clickhouse │ │ │ ├── clickhouseexecutesql │ │ │ │ ├── clickhouseexecutesql_test.go │ │ │ │ └── clickhouseexecutesql.go │ │ │ ├── clickhouselistdatabases │ │ │ │ ├── clickhouselistdatabases_test.go │ │ │ │ └── clickhouselistdatabases.go │ │ │ ├── clickhouselisttables │ │ │ │ ├── clickhouselisttables_test.go │ │ │ │ └── clickhouselisttables.go │ │ │ └── clickhousesql │ │ │ ├── clickhousesql_test.go │ │ │ └── clickhousesql.go │ │ ├── cloudmonitoring │ │ │ ├── cloudmonitoring_test.go │ │ │ └── cloudmonitoring.go │ │ ├── cloudsql │ │ │ ├── cloudsqlcreatedatabase │ │ │ │ ├── cloudsqlcreatedatabase_test.go │ │ │ │ └── cloudsqlcreatedatabase.go │ │ │ ├── cloudsqlcreateusers │ │ │ │ ├── cloudsqlcreateusers_test.go │ │ │ │ └── cloudsqlcreateusers.go │ │ │ ├── cloudsqlgetinstances │ │ │ │ ├── cloudsqlgetinstances_test.go │ │ │ │ └── cloudsqlgetinstances.go │ │ │ ├── cloudsqllistdatabases │ │ │ │ ├── cloudsqllistdatabases_test.go │ │ │ │ └── cloudsqllistdatabases.go │ │ │ ├── cloudsqllistinstances │ │ │ │ ├── cloudsqllistinstances_test.go │ │ │ │ └── cloudsqllistinstances.go │ │ │ └── cloudsqlwaitforoperation │ │ │ ├── cloudsqlwaitforoperation_test.go │ │ │ └── cloudsqlwaitforoperation.go │ │ ├── cloudsqlmssql │ │ │ └── cloudsqlmssqlcreateinstance │ │ │ ├── cloudsqlmssqlcreateinstance_test.go │ │ │ └── cloudsqlmssqlcreateinstance.go │ │ ├── cloudsqlmysql │ │ │ └── cloudsqlmysqlcreateinstance │ │ │ ├── cloudsqlmysqlcreateinstance_test.go │ │ │ └── cloudsqlmysqlcreateinstance.go │ │ ├── cloudsqlpg │ │ │ └── cloudsqlpgcreateinstances │ │ │ ├── cloudsqlpgcreateinstances_test.go │ │ │ └── cloudsqlpgcreateinstances.go │ │ ├── common_test.go │ │ ├── common.go │ │ ├── couchbase │ │ │ ├── couchbase_test.go │ │ │ └── couchbase.go │ │ ├── dataform │ │ │ └── dataformcompilelocal │ │ │ ├── dataformcompilelocal_test.go │ │ │ └── dataformcompilelocal.go │ │ ├── dataplex │ │ │ ├── dataplexlookupentry │ │ │ │ ├── dataplexlookupentry_test.go │ │ │ │ └── dataplexlookupentry.go │ │ │ ├── dataplexsearchaspecttypes │ │ │ │ ├── dataplexsearchaspecttypes_test.go │ │ │ │ └── dataplexsearchaspecttypes.go │ │ │ └── dataplexsearchentries │ │ │ ├── dataplexsearchentries_test.go │ │ │ └── dataplexsearchentries.go │ │ ├── dgraph │ │ │ ├── dgraph_test.go │ │ │ └── dgraph.go │ │ ├── firebird │ │ │ ├── firebirdexecutesql │ │ │ │ ├── firebirdexecutesql_test.go │ │ │ │ └── firebirdexecutesql.go │ │ │ └── firebirdsql │ │ │ ├── firebirdsql_test.go │ │ │ └── firebirdsql.go │ │ ├── firestore │ │ │ ├── firestoreadddocuments │ │ │ │ ├── firestoreadddocuments_test.go │ │ │ │ └── firestoreadddocuments.go │ │ │ ├── firestoredeletedocuments │ │ │ │ ├── 
firestoredeletedocuments_test.go │ │ │ │ └── firestoredeletedocuments.go │ │ │ ├── firestoregetdocuments │ │ │ │ ├── firestoregetdocuments_test.go │ │ │ │ └── firestoregetdocuments.go │ │ │ ├── firestoregetrules │ │ │ │ ├── firestoregetrules_test.go │ │ │ │ └── firestoregetrules.go │ │ │ ├── firestorelistcollections │ │ │ │ ├── firestorelistcollections_test.go │ │ │ │ └── firestorelistcollections.go │ │ │ ├── firestorequery │ │ │ │ ├── firestorequery_test.go │ │ │ │ └── firestorequery.go │ │ │ ├── firestorequerycollection │ │ │ │ ├── firestorequerycollection_test.go │ │ │ │ └── firestorequerycollection.go │ │ │ ├── firestoreupdatedocument │ │ │ │ ├── firestoreupdatedocument_test.go │ │ │ │ └── firestoreupdatedocument.go │ │ │ ├── firestorevalidaterules │ │ │ │ ├── firestorevalidaterules_test.go │ │ │ │ └── firestorevalidaterules.go │ │ │ └── util │ │ │ ├── converter_test.go │ │ │ ├── converter.go │ │ │ ├── validator_test.go │ │ │ └── validator.go │ │ ├── http │ │ │ ├── http_test.go │ │ │ └── http.go │ │ ├── http_method.go │ │ ├── looker │ │ │ ├── lookeradddashboardelement │ │ │ │ ├── lookeradddashboardelement_test.go │ │ │ │ └── lookeradddashboardelement.go │ │ │ ├── lookercommon │ │ │ │ ├── lookercommon_test.go │ │ │ │ └── lookercommon.go │ │ │ ├── lookerconversationalanalytics │ │ │ │ ├── lookerconversationalanalytics_test.go │ │ │ │ └── lookerconversationalanalytics.go │ │ │ ├── lookergetdashboards │ │ │ │ ├── lookergetdashboards_test.go │ │ │ │ └── lookergetdashboards.go │ │ │ ├── lookergetdimensions │ │ │ │ ├── lookergetdimensions_test.go │ │ │ │ └── lookergetdimensions.go │ │ │ ├── lookergetexplores │ │ │ │ ├── lookergetexplores_test.go │ │ │ │ └── lookergetexplores.go │ │ │ ├── lookergetfilters │ │ │ │ ├── lookergetfilters_test.go │ │ │ │ └── lookergetfilters.go │ │ │ ├── lookergetlooks │ │ │ │ ├── lookergetlooks_test.go │ │ │ │ └── lookergetlooks.go │ │ │ ├── lookergetmeasures │ │ │ │ ├── lookergetmeasures_test.go │ │ │ │ └── lookergetmeasures.go │ │ │ ├── lookergetmodels │ │ │ │ ├── lookergetmodels_test.go │ │ │ │ └── lookergetmodels.go │ │ │ ├── lookergetparameters │ │ │ │ ├── lookergetparameters_test.go │ │ │ │ └── lookergetparameters.go │ │ │ ├── lookerhealthanalyze │ │ │ │ ├── lookerhealthanalyze_test.go │ │ │ │ └── lookerhealthanalyze.go │ │ │ ├── lookerhealthpulse │ │ │ │ ├── lookerhealthpulse_test.go │ │ │ │ └── lookerhealthpulse.go │ │ │ ├── lookerhealthvacuum │ │ │ │ ├── lookerhealthvacuum_test.go │ │ │ │ └── lookerhealthvacuum.go │ │ │ ├── lookermakedashboard │ │ │ │ ├── lookermakedashboard_test.go │ │ │ │ └── lookermakedashboard.go │ │ │ ├── lookermakelook │ │ │ │ ├── lookermakelook_test.go │ │ │ │ └── lookermakelook.go │ │ │ ├── lookerquery │ │ │ │ ├── lookerquery_test.go │ │ │ │ └── lookerquery.go │ │ │ ├── lookerquerysql │ │ │ │ ├── lookerquerysql_test.go │ │ │ │ └── lookerquerysql.go │ │ │ ├── lookerqueryurl │ │ │ │ ├── lookerqueryurl_test.go │ │ │ │ └── lookerqueryurl.go │ │ │ └── lookerrunlook │ │ │ ├── lookerrunlook_test.go │ │ │ └── lookerrunlook.go │ │ ├── mongodb │ │ │ ├── mongodbaggregate │ │ │ │ ├── mongodbaggregate_test.go │ │ │ │ └── mongodbaggregate.go │ │ │ ├── mongodbdeletemany │ │ │ │ ├── mongodbdeletemany_test.go │ │ │ │ └── mongodbdeletemany.go │ │ │ ├── mongodbdeleteone │ │ │ │ ├── mongodbdeleteone_test.go │ │ │ │ └── mongodbdeleteone.go │ │ │ ├── mongodbfind │ │ │ │ ├── mongodbfind_test.go │ │ │ │ └── mongodbfind.go │ │ │ ├── mongodbfindone │ │ │ │ ├── mongodbfindone_test.go │ │ │ │ └── mongodbfindone.go │ │ │ ├── mongodbinsertmany │ │ │ │ ├── 
mongodbinsertmany_test.go │ │ │ │ └── mongodbinsertmany.go │ │ │ ├── mongodbinsertone │ │ │ │ ├── mongodbinsertone_test.go │ │ │ │ └── mongodbinsertone.go │ │ │ ├── mongodbupdatemany │ │ │ │ ├── mongodbupdatemany_test.go │ │ │ │ └── mongodbupdatemany.go │ │ │ └── mongodbupdateone │ │ │ ├── mongodbupdateone_test.go │ │ │ └── mongodbupdateone.go │ │ ├── mssql │ │ │ ├── mssqlexecutesql │ │ │ │ ├── mssqlexecutesql_test.go │ │ │ │ └── mssqlexecutesql.go │ │ │ ├── mssqllisttables │ │ │ │ ├── mssqllisttables_test.go │ │ │ │ └── mssqllisttables.go │ │ │ └── mssqlsql │ │ │ ├── mssqlsql_test.go │ │ │ └── mssqlsql.go │ │ ├── mysql │ │ │ ├── mysqlcommon │ │ │ │ └── mysqlcommon.go │ │ │ ├── mysqlexecutesql │ │ │ │ ├── mysqlexecutesql_test.go │ │ │ │ └── mysqlexecutesql.go │ │ │ ├── mysqllistactivequeries │ │ │ │ ├── mysqllistactivequeries_test.go │ │ │ │ └── mysqllistactivequeries.go │ │ │ ├── mysqllisttablefragmentation │ │ │ │ ├── mysqllisttablefragmentation_test.go │ │ │ │ └── mysqllisttablefragmentation.go │ │ │ ├── mysqllisttables │ │ │ │ ├── mysqllisttables_test.go │ │ │ │ └── mysqllisttables.go │ │ │ ├── mysqllisttablesmissinguniqueindexes │ │ │ │ ├── mysqllisttablesmissinguniqueindexes_test.go │ │ │ │ └── mysqllisttablesmissinguniqueindexes.go │ │ │ └── mysqlsql │ │ │ ├── mysqlsql_test.go │ │ │ └── mysqlsql.go │ │ ├── neo4j │ │ │ ├── neo4jcypher │ │ │ │ ├── neo4jcypher_test.go │ │ │ │ └── neo4jcypher.go │ │ │ ├── neo4jexecutecypher │ │ │ │ ├── classifier │ │ │ │ │ ├── classifier_test.go │ │ │ │ │ └── classifier.go │ │ │ │ ├── neo4jexecutecypher_test.go │ │ │ │ └── neo4jexecutecypher.go │ │ │ └── neo4jschema │ │ │ ├── cache │ │ │ │ ├── cache_test.go │ │ │ │ └── cache.go │ │ │ ├── helpers │ │ │ │ ├── helpers_test.go │ │ │ │ └── helpers.go │ │ │ ├── neo4jschema_test.go │ │ │ ├── neo4jschema.go │ │ │ └── types │ │ │ └── types.go │ │ ├── oceanbase │ │ │ ├── oceanbaseexecutesql │ │ │ │ ├── oceanbaseexecutesql_test.go │ │ │ │ └── oceanbaseexecutesql.go │ │ │ └── oceanbasesql │ │ │ ├── oceanbasesql_test.go │ │ │ └── oceanbasesql.go │ │ ├── oracle │ │ │ ├── oracleexecutesql │ │ │ │ └── oracleexecutesql.go │ │ │ └── oraclesql │ │ │ └── oraclesql.go │ │ ├── parameters_test.go │ │ ├── parameters.go │ │ ├── postgres │ │ │ ├── postgresexecutesql │ │ │ │ ├── postgresexecutesql_test.go │ │ │ │ └── postgresexecutesql.go │ │ │ ├── postgreslistactivequeries │ │ │ │ ├── postgreslistactivequeries_test.go │ │ │ │ └── postgreslistactivequeries.go │ │ │ ├── postgreslistavailableextensions │ │ │ │ ├── postgreslistavailableextensions_test.go │ │ │ │ └── postgreslistavailableextensions.go │ │ │ ├── postgreslistinstalledextensions │ │ │ │ ├── postgreslistinstalledextensions_test.go │ │ │ │ └── postgreslistinstalledextensions.go │ │ │ ├── postgreslisttables │ │ │ │ ├── postgreslisttables_test.go │ │ │ │ └── postgreslisttables.go │ │ │ └── postgressql │ │ │ ├── postgressql_test.go │ │ │ └── postgressql.go │ │ ├── redis │ │ │ ├── redis_test.go │ │ │ └── redis.go │ │ ├── spanner │ │ │ ├── spannerexecutesql │ │ │ │ ├── spannerexecutesql_test.go │ │ │ │ └── spannerexecutesql.go │ │ │ ├── spannerlisttables │ │ │ │ ├── spannerlisttables_test.go │ │ │ │ └── spannerlisttables.go │ │ │ └── spannersql │ │ │ ├── spanner_test.go │ │ │ └── spannersql.go │ │ ├── sqlite │ │ │ ├── sqliteexecutesql │ │ │ │ ├── sqliteexecutesql_test.go │ │ │ │ └── sqliteexecutesql.go │ │ │ └── sqlitesql │ │ │ ├── sqlitesql_test.go │ │ │ └── sqlitesql.go │ │ ├── tidb │ │ │ ├── tidbexecutesql │ │ │ │ ├── tidbexecutesql_test.go │ │ │ │ └── tidbexecutesql.go │ │ 
│ └── tidbsql │ │ │ ├── tidbsql_test.go │ │ │ └── tidbsql.go │ │ ├── tools_test.go │ │ ├── tools.go │ │ ├── toolsets.go │ │ ├── trino │ │ │ ├── trinoexecutesql │ │ │ │ ├── trinoexecutesql_test.go │ │ │ │ └── trinoexecutesql.go │ │ │ └── trinosql │ │ │ ├── trinosql_test.go │ │ │ └── trinosql.go │ │ ├── utility │ │ │ └── wait │ │ │ ├── wait_test.go │ │ │ └── wait.go │ │ ├── valkey │ │ │ ├── valkey_test.go │ │ │ └── valkey.go │ │ └── yugabytedbsql │ │ ├── yugabytedbsql_test.go │ │ └── yugabytedbsql.go │ └── util │ └── util.go ├── LICENSE ├── logo.png ├── main.go ├── README.md └── tests ├── alloydb │ ├── alloydb_integration_test.go │ └── alloydb_wait_for_operation_test.go ├── alloydbainl │ └── alloydb_ai_nl_integration_test.go ├── alloydbpg │ └── alloydb_pg_integration_test.go ├── auth.go ├── bigquery │ └── bigquery_integration_test.go ├── bigtable │ └── bigtable_integration_test.go ├── cassandra │ └── cassandra_integration_test.go ├── clickhouse │ └── clickhouse_integration_test.go ├── cloudmonitoring │ └── cloud_monitoring_integration_test.go ├── cloudsql │ ├── cloud_sql_create_database_test.go │ ├── cloud_sql_create_users_test.go │ ├── cloud_sql_get_instances_test.go │ ├── cloud_sql_list_databases_test.go │ ├── cloudsql_list_instances_test.go │ └── cloudsql_wait_for_operation_test.go ├── cloudsqlmssql │ ├── cloud_sql_mssql_create_instance_integration_test.go │ └── cloud_sql_mssql_integration_test.go ├── cloudsqlmysql │ ├── cloud_sql_mysql_create_instance_integration_test.go │ └── cloud_sql_mysql_integration_test.go ├── cloudsqlpg │ ├── cloud_sql_pg_create_instances_test.go │ └── cloud_sql_pg_integration_test.go ├── common.go ├── couchbase │ └── couchbase_integration_test.go ├── dataform │ └── dataform_integration_test.go ├── dataplex │ └── dataplex_integration_test.go ├── dgraph │ └── dgraph_integration_test.go ├── firebird │ └── firebird_integration_test.go ├── firestore │ └── firestore_integration_test.go ├── http │ └── http_integration_test.go ├── looker │ └── looker_integration_test.go ├── mongodb │ └── mongodb_integration_test.go ├── mssql │ └── mssql_integration_test.go ├── mysql │ └── mysql_integration_test.go ├── neo4j │ └── neo4j_integration_test.go ├── oceanbase │ └── oceanbase_integration_test.go ├── option.go ├── oracle │ └── oracle_integration_test.go ├── postgres │ └── postgres_integration_test.go ├── redis │ └── redis_test.go ├── server.go ├── source.go ├── spanner │ └── spanner_integration_test.go ├── sqlite │ └── sqlite_integration_test.go ├── tidb │ └── tidb_integration_test.go ├── tool.go ├── trino │ └── trino_integration_test.go ├── utility │ └── wait_integration_test.go ├── valkey │ └── valkey_test.go └── yugabytedb └── yugabytedb_integration_test.go ``` # Files -------------------------------------------------------------------------------- /tests/bigquery/bigquery_integration_test.go: -------------------------------------------------------------------------------- ```go 1 | // Copyright 2025 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
 12 | // See the License for the specific language governing permissions and
 13 | // limitations under the License.
 14 |
 15 | package bigquery
 16 |
 17 | import (
 18 |   "bytes"
 19 |   "context"
 20 |   "encoding/json"
 21 |   "fmt"
 22 |   "io"
 23 |   "net/http"
 24 |   "os"
 25 |   "reflect"
 26 |   "regexp"
 27 |   "sort"
 28 |   "strings"
 29 |   "testing"
 30 |   "time"
 31 |
 32 |   bigqueryapi "cloud.google.com/go/bigquery"
 33 |   "github.com/google/uuid"
 34 |   "github.com/googleapis/genai-toolbox/internal/sources"
 35 |   "github.com/googleapis/genai-toolbox/internal/testutils"
 36 |   "github.com/googleapis/genai-toolbox/tests"
 37 |   "golang.org/x/oauth2/google"
 38 |   "google.golang.org/api/googleapi"
 39 |   "google.golang.org/api/iterator"
 40 |   "google.golang.org/api/option"
 41 | )
 42 |
 43 | var (
 44 |   BigquerySourceKind = "bigquery"
 45 |   BigqueryToolKind = "bigquery-sql"
 46 |   BigqueryProject = os.Getenv("BIGQUERY_PROJECT")
 47 | )
 48 |
 49 | func getBigQueryVars(t *testing.T) map[string]any {
 50 |   switch "" {
 51 |   case BigqueryProject:
 52 |     t.Fatal("'BIGQUERY_PROJECT' not set")
 53 |   }
 54 |
 55 |   return map[string]any{
 56 |     "kind": BigquerySourceKind,
 57 |     "project": BigqueryProject,
 58 |   }
 59 | }
 60 |
 61 | // Copied over from bigquery.go
 62 | func initBigQueryConnection(project string) (*bigqueryapi.Client, error) {
 63 |   ctx := context.Background()
 64 |   cred, err := google.FindDefaultCredentials(ctx, bigqueryapi.Scope)
 65 |   if err != nil {
 66 |     return nil, fmt.Errorf("failed to find default Google Cloud credentials with scope %q: %w", bigqueryapi.Scope, err)
 67 |   }
 68 |
 69 |   client, err := bigqueryapi.NewClient(ctx, project, option.WithCredentials(cred))
 70 |   if err != nil {
 71 |     return nil, fmt.Errorf("failed to create BigQuery client for project %q: %w", project, err)
 72 |   }
 73 |   return client, nil
 74 | }
 75 |
 76 | func TestBigQueryToolEndpoints(t *testing.T) {
 77 |   sourceConfig := getBigQueryVars(t)
 78 |   ctx, cancel := context.WithTimeout(context.Background(), 7*time.Minute)
 79 |   defer cancel()
 80 |
 81 |   var args []string
 82 |
 83 |   client, err := initBigQueryConnection(BigqueryProject)
 84 |   if err != nil {
 85 |     t.Fatalf("unable to create BigQuery connection: %s", err)
 86 |   }
 87 |
 88 |   // create table name with UUID
 89 |   datasetName := fmt.Sprintf("temp_toolbox_test_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
 90 |   tableName := fmt.Sprintf("param_table_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
 91 |   tableNameParam := fmt.Sprintf("`%s.%s.%s`",
 92 |     BigqueryProject,
 93 |     datasetName,
 94 |     tableName,
 95 |   )
 96 |   tableNameAuth := fmt.Sprintf("`%s.%s.auth_table_%s`",
 97 |     BigqueryProject,
 98 |     datasetName,
 99 |     strings.ReplaceAll(uuid.New().String(), "-", ""),
100 |   )
101 |   tableNameTemplateParam := fmt.Sprintf("`%s.%s.template_param_table_%s`",
102 |     BigqueryProject,
103 |     datasetName,
104 |     strings.ReplaceAll(uuid.New().String(), "-", ""),
105 |   )
106 |   tableNameDataType := fmt.Sprintf("`%s.%s.datatype_table_%s`",
107 |     BigqueryProject,
108 |     datasetName,
109 |     strings.ReplaceAll(uuid.New().String(), "-", ""),
110 |   )
111 |   tableNameForecast := fmt.Sprintf("`%s.%s.forecast_table_%s`",
112 |     BigqueryProject,
113 |     datasetName,
114 |     strings.ReplaceAll(uuid.New().String(), "-", ""),
115 |   )
116 |
117 |   tableNameAnalyzeContribution := fmt.Sprintf("`%s.%s.analyze_contribution_table_%s`",
118 |     BigqueryProject,
119 |     datasetName,
120 |     strings.ReplaceAll(uuid.New().String(), "-", ""),
121 |   )
122 |
123 |   // set up data for param tool
124 |   createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := getBigQueryParamToolInfo(tableNameParam)
125 |   teardownTable1 := setupBigQueryTable(t, ctx, client, createParamTableStmt, insertParamTableStmt, datasetName, tableNameParam, paramTestParams)
126 |   defer teardownTable1(t)
127 |
128 |   // set up data for auth tool
129 |   createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getBigQueryAuthToolInfo(tableNameAuth)
130 |   teardownTable2 := setupBigQueryTable(t, ctx, client, createAuthTableStmt, insertAuthTableStmt, datasetName, tableNameAuth, authTestParams)
131 |   defer teardownTable2(t)
132 |
133 |   // set up data for data type test tool
134 |   createDataTypeTableStmt, insertDataTypeTableStmt, dataTypeToolStmt, arrayDataTypeToolStmt, dataTypeTestParams := getBigQueryDataTypeTestInfo(tableNameDataType)
135 |   teardownTable3 := setupBigQueryTable(t, ctx, client, createDataTypeTableStmt, insertDataTypeTableStmt, datasetName, tableNameDataType, dataTypeTestParams)
136 |   defer teardownTable3(t)
137 |
138 |   // set up data for forecast tool
139 |   createForecastTableStmt, insertForecastTableStmt, forecastTestParams := getBigQueryForecastToolInfo(tableNameForecast)
140 |   teardownTable4 := setupBigQueryTable(t, ctx, client, createForecastTableStmt, insertForecastTableStmt, datasetName, tableNameForecast, forecastTestParams)
141 |   defer teardownTable4(t)
142 |
143 |   // set up data for analyze contribution tool
144 |   createAnalyzeContributionTableStmt, insertAnalyzeContributionTableStmt, analyzeContributionTestParams := getBigQueryAnalyzeContributionToolInfo(tableNameAnalyzeContribution)
145 |   teardownTable5 := setupBigQueryTable(t, ctx, client, createAnalyzeContributionTableStmt, insertAnalyzeContributionTableStmt, datasetName, tableNameAnalyzeContribution, analyzeContributionTestParams)
146 |   defer teardownTable5(t)
147 |
148 |   // Write config into a file and pass it to command
149 |   toolsFile := tests.GetToolsConfig(sourceConfig, BigqueryToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
150 |   toolsFile = addClientAuthSourceConfig(t, toolsFile)
151 |   toolsFile = addBigQuerySqlToolConfig(t, toolsFile, dataTypeToolStmt, arrayDataTypeToolStmt)
152 |   toolsFile = addBigQueryPrebuiltToolsConfig(t, toolsFile)
153 |   tmplSelectCombined, tmplSelectFilterCombined := getBigQueryTmplToolStatement()
154 |   toolsFile = tests.AddTemplateParamConfig(t, toolsFile, BigqueryToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
155 |
156 |   cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
157 |   if err != nil {
158 |     t.Fatalf("command initialization returned an error: %s", err)
159 |   }
160 |   defer cleanup()
161 |
162 |   waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
163 |   defer cancel()
164 |   out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
165 |   if err != nil {
166 |     t.Logf("toolbox command logs: \n%s", out)
167 |     t.Fatalf("toolbox didn't start successfully: %s", err)
168 |   }
169 |
170 |   // Get configs for tests
171 |   select1Want := "[{\"f0_\":1}]"
172 |   invokeParamWant := "[{\"id\":1,\"name\":\"Alice\"},{\"id\":3,\"name\":\"Sid\"}]"
173 |   datasetInfoWant := "\"Location\":\"US\",\"DefaultTableExpiration\":0,\"Labels\":null,\"Access\":"
174 |   tableInfoWant := "{\"Name\":\"\",\"Location\":\"US\",\"Description\":\"\",\"Schema\":[{\"Name\":\"id\""
175 |   ddlWant := `"Query executed successfully and returned no content."`
176 |   dataInsightsWant := `(?s)Schema Resolved.*Retrieval Query.*SQL Generated.*Answer`
177 |   // Partial message; the full error message is too long.
178 |   mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"query validation failed: failed to insert dry run job: googleapi: Error 400: Syntax error: Unexpected identifier \"SELEC\" at [1:1]`
179 |   mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"f0_\":1}"}]}}`
180 |   createColArray := `["id INT64", "name STRING", "age INT64"]`
181 |   selectEmptyWant := `"The query returned 0 rows."`
182 |
183 |   // Run tests
184 |   tests.RunToolGetTest(t)
185 |   tests.RunToolInvokeTest(t, select1Want, tests.DisableOptionalNullParamTest(), tests.EnableClientAuthTest())
186 |   tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want, tests.EnableMcpClientAuthTest())
187 |   tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam,
188 |     tests.WithCreateColArray(createColArray),
189 |     tests.WithDdlWant(ddlWant),
190 |     tests.WithSelectEmptyWant(selectEmptyWant),
191 |     tests.WithInsert1Want(ddlWant),
192 |   )
193 |
194 |   runBigQueryExecuteSqlToolInvokeTest(t, select1Want, invokeParamWant, tableNameParam, ddlWant)
195 |   runBigQueryExecuteSqlToolInvokeDryRunTest(t, datasetName)
196 |   runBigQueryForecastToolInvokeTest(t, tableNameForecast)
197 |   runBigQueryAnalyzeContributionToolInvokeTest(t, tableNameAnalyzeContribution)
198 |   runBigQueryDataTypeTests(t)
199 |   runBigQueryListDatasetToolInvokeTest(t, datasetName)
200 |   runBigQueryGetDatasetInfoToolInvokeTest(t, datasetName, datasetInfoWant)
201 |   runBigQueryListTableIdsToolInvokeTest(t, datasetName, tableName)
202 |   runBigQueryGetTableInfoToolInvokeTest(t, datasetName, tableName, tableInfoWant)
203 |   runBigQueryConversationalAnalyticsInvokeTest(t, datasetName, tableName, dataInsightsWant)
204 |   runBigQuerySearchCatalogToolInvokeTest(t, datasetName, tableName)
205 | }
206 |
207 | func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
208 |   ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
209 |   defer cancel()
210 |
211 |   client, err := initBigQueryConnection(BigqueryProject)
212 |   if err != nil {
213 |     t.Fatalf("unable to create BigQuery client: %s", err)
214 |   }
215 |
216 |   // Create two datasets, one allowed, one not.
217 |   baseName := strings.ReplaceAll(uuid.New().String(), "-", "")
218 |   allowedDatasetName1 := fmt.Sprintf("allowed_dataset_1_%s", baseName)
219 |   allowedDatasetName2 := fmt.Sprintf("allowed_dataset_2_%s", baseName)
220 |   disallowedDatasetName := fmt.Sprintf("disallowed_dataset_%s", baseName)
221 |   allowedTableName1 := "allowed_table_1"
222 |   allowedTableName2 := "allowed_table_2"
223 |   disallowedTableName := "disallowed_table"
224 |   allowedForecastTableName1 := "allowed_forecast_table_1"
225 |   allowedForecastTableName2 := "allowed_forecast_table_2"
226 |   disallowedForecastTableName := "disallowed_forecast_table"
227 |
228 |   // Setup allowed table
229 |   allowedTableNameParam1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedTableName1)
230 |   createAllowedTableStmt1 := fmt.Sprintf("CREATE TABLE %s (id INT64)", allowedTableNameParam1)
231 |   teardownAllowed1 := setupBigQueryTable(t, ctx, client, createAllowedTableStmt1, "", allowedDatasetName1, allowedTableNameParam1, nil)
232 |   defer teardownAllowed1(t)
233 |
234 |   allowedTableNameParam2 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName2, allowedTableName2)
235 |   createAllowedTableStmt2 := fmt.Sprintf("CREATE TABLE %s (id INT64)", allowedTableNameParam2)
236 |   teardownAllowed2 := setupBigQueryTable(t, ctx, client, createAllowedTableStmt2, "", allowedDatasetName2, allowedTableNameParam2, nil)
237 |   defer teardownAllowed2(t)
238 |
239 |   // Setup allowed forecast table
240 |   allowedForecastTableFullName1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedForecastTableName1)
241 |   createForecastStmt1, insertForecastStmt1, forecastParams1 := getBigQueryForecastToolInfo(allowedForecastTableFullName1)
242 |   teardownAllowedForecast1 := setupBigQueryTable(t, ctx, client, createForecastStmt1, insertForecastStmt1, allowedDatasetName1, allowedForecastTableFullName1, forecastParams1)
243 |   defer teardownAllowedForecast1(t)
244 |
245 |   allowedForecastTableFullName2 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName2, allowedForecastTableName2)
246 |   createForecastStmt2, insertForecastStmt2, forecastParams2 := getBigQueryForecastToolInfo(allowedForecastTableFullName2)
247 |   teardownAllowedForecast2 := setupBigQueryTable(t, ctx, client, createForecastStmt2, insertForecastStmt2, allowedDatasetName2, allowedForecastTableFullName2, forecastParams2)
248 |   defer teardownAllowedForecast2(t)
249 |
250 |   // Setup disallowed table
251 |   disallowedTableNameParam := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, disallowedTableName)
252 |   createDisallowedTableStmt := fmt.Sprintf("CREATE TABLE %s (id INT64)", disallowedTableNameParam)
253 |   teardownDisallowed := setupBigQueryTable(t, ctx, client, createDisallowedTableStmt, "", disallowedDatasetName, disallowedTableNameParam, nil)
254 |   defer teardownDisallowed(t)
255 |
256 |   // Setup disallowed forecast table
257 |   disallowedForecastTableFullName := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, disallowedForecastTableName)
258 |   createDisallowedForecastStmt, insertDisallowedForecastStmt, disallowedForecastParams := getBigQueryForecastToolInfo(disallowedForecastTableFullName)
259 |   teardownDisallowedForecast := setupBigQueryTable(t, ctx, client, createDisallowedForecastStmt, insertDisallowedForecastStmt, disallowedDatasetName, disallowedForecastTableFullName, disallowedForecastParams)
260 |   defer teardownDisallowedForecast(t)
261 |
262 |   // Configure source with dataset restriction.
263 |   sourceConfig := getBigQueryVars(t)
264 |   sourceConfig["allowedDatasets"] = []string{allowedDatasetName1, allowedDatasetName2}
265 |
266 |   // Configure tool
267 |   toolsConfig := map[string]any{
268 |     "list-dataset-ids-restricted": map[string]any{
269 |       "kind": "bigquery-list-dataset-ids",
270 |       "source": "my-instance",
271 |       "description": "Tool to list dataset ids",
272 |     },
273 |     "list-table-ids-restricted": map[string]any{
274 |       "kind": "bigquery-list-table-ids",
275 |       "source": "my-instance",
276 |       "description": "Tool to list table within a dataset",
277 |     },
278 |     "get-table-info-restricted": map[string]any{
279 |       "kind": "bigquery-get-table-info",
280 |       "source": "my-instance",
281 |       "description": "Tool to get table info",
282 |     },
283 |     "execute-sql-restricted": map[string]any{
284 |       "kind": "bigquery-execute-sql",
285 |       "source": "my-instance",
286 |       "description": "Tool to execute SQL",
287 |     },
288 |     "conversational-analytics-restricted": map[string]any{
289 |       "kind": "bigquery-conversational-analytics",
290 |       "source": "my-instance",
291 |       "description": "Tool to ask BigQuery conversational analytics",
292 |     },
293 |     "forecast-restricted": map[string]any{
294 |       "kind": "bigquery-forecast",
295 |       "source": "my-instance",
296 |       "description": "Tool to forecast",
297 |     },
298 |   }
299 |
300 |   // Create config file
301 |   config := map[string]any{
302 |     "sources": map[string]any{
303 |       "my-instance": sourceConfig,
304 |     },
305 |     "tools": toolsConfig,
306 |   }
307 |
308 |   // Start server
309 |   cmd, cleanup, err := tests.StartCmd(ctx, config)
310 |   if err != nil {
311 |     t.Fatalf("command initialization returned an error: %s", err)
312 |   }
313 |   defer cleanup()
314 |
315 |   waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
316 |   defer cancel()
317 |   out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
318 |   if err != nil {
319 |     t.Logf("toolbox command logs: \n%s", out)
320 |     t.Fatalf("toolbox didn't start successfully: %s", err)
321 |   }
322 |
323 |   // Run tests
324 |   runListDatasetIdsWithRestriction(t, allowedDatasetName1, allowedDatasetName2)
325 |   runListTableIdsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, allowedForecastTableName1)
326 |   runListTableIdsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, allowedForecastTableName2)
327 |   runGetTableInfoWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, disallowedTableName)
328 |   runGetTableInfoWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, disallowedTableName)
329 |   runExecuteSqlWithRestriction(t, allowedTableNameParam1, disallowedTableNameParam)
330 |   runExecuteSqlWithRestriction(t, allowedTableNameParam2, disallowedTableNameParam)
331 |   runConversationalAnalyticsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, disallowedTableName)
332 |   runConversationalAnalyticsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, disallowedTableName)
333 |   runForecastWithRestriction(t, allowedForecastTableFullName1, disallowedForecastTableFullName)
334 |   runForecastWithRestriction(t, allowedForecastTableFullName2, disallowedForecastTableFullName)
335 | }
336 |
337 | func TestBigQueryWriteModeAllowed(t *testing.T) {
338 |   sourceConfig := getBigQueryVars(t)
339 |   sourceConfig["writeMode"] = "allowed"
340 |
341 |   ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
342 |   defer cancel()
343 |
344 |   datasetName := fmt.Sprintf("temp_toolbox_test_allowed_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
345 |
346 |   client, err := initBigQueryConnection(BigqueryProject)
347 |   if err != nil {
348 |     t.Fatalf("unable to create BigQuery connection: %s", err)
349 |   }
350 |
351 |   dataset := client.Dataset(datasetName)
352 |   if err := dataset.Create(ctx, &bigqueryapi.DatasetMetadata{Name: datasetName}); err != nil {
353 |     t.Fatalf("Failed to create dataset %q: %v", datasetName, err)
354 |   }
355 |   defer func() {
356 |     if err := dataset.DeleteWithContents(ctx); err != nil {
357 |       t.Logf("failed to cleanup dataset %s: %v", datasetName, err)
358 |     }
359 |   }()
360 |
361 |   toolsFile := map[string]any{
362 |     "sources": map[string]any{
363 |       "my-instance": sourceConfig,
364 |     },
365 |     "tools": map[string]any{
366 |       "my-exec-sql-tool": map[string]any{
367 |         "kind": "bigquery-execute-sql",
368 |         "source": "my-instance",
369 |         "description": "Tool to execute sql",
370 |       },
371 |     },
372 |   }
373 |
374 |   cmd, cleanup, err := tests.StartCmd(ctx, toolsFile)
375 |   if err != nil {
376 |     t.Fatalf("command initialization returned an error: %s", err)
377 |   }
378 |   defer cleanup()
379 |
380 |   waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
381 |   defer cancel()
382 |   out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
383 |   if err != nil {
384 |     t.Logf("toolbox command logs: \n%s", out)
385 |     t.Fatalf("toolbox didn't start successfully: %s", err)
386 |   }
387 |
388 |   runBigQueryWriteModeAllowedTest(t, datasetName)
389 | }
390 |
391 | func TestBigQueryWriteModeBlocked(t *testing.T) {
392 |   sourceConfig := getBigQueryVars(t)
393 |   sourceConfig["writeMode"] = "blocked"
394 |
395 |   ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
396 |   defer cancel()
397 |
398 |   datasetName := fmt.Sprintf("temp_toolbox_test_blocked_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
399 |   tableName := fmt.Sprintf("param_table_blocked_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
400 |   tableNameParam := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, datasetName, tableName)
401 |
402 |   client, err := initBigQueryConnection(BigqueryProject)
403 |   if err != nil {
404 |     t.Fatalf("unable to create BigQuery connection: %s", err)
405 |   }
406 |   createParamTableStmt, insertParamTableStmt, _, _, _, _, paramTestParams := getBigQueryParamToolInfo(tableNameParam)
407 |   teardownTable := setupBigQueryTable(t, ctx, client, createParamTableStmt, insertParamTableStmt, datasetName, tableNameParam, paramTestParams)
408 |   defer teardownTable(t)
409 |
410 |   toolsFile := map[string]any{
411 |     "sources": map[string]any{"my-instance": sourceConfig},
412 |     "tools": map[string]any{
413 |       "my-exec-sql-tool": map[string]any{"kind": "bigquery-execute-sql", "source": "my-instance", "description": "Tool to execute sql"},
414 |     },
415 |   }
416 |
417 |   cmd, cleanup, err := tests.StartCmd(ctx, toolsFile)
418 |   if err != nil {
419 |     t.Fatalf("command initialization returned an error: %s", err)
420 |   }
421 |   defer cleanup()
422 |
423 |   waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
424 |   defer cancel()
425 |   out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
426 |   if err != nil {
427 |     t.Logf("toolbox command logs: \n%s", out)
428 |     t.Fatalf("toolbox didn't start successfully: %s", err)
429 |   }
430 |
431 |   runBigQueryWriteModeBlockedTest(t, tableNameParam, datasetName)
432 | }
433 |
434 | func TestBigQueryWriteModeProtected(t *testing.T) {
435 |   sourceConfig := getBigQueryVars(t)
436 |   sourceConfig["writeMode"] = "protected"
437 |
438 |   ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
439 |   defer cancel()
440 |
441 |   permanentDatasetName := fmt.Sprintf("perm_dataset_protected_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
442 |   client, err := initBigQueryConnection(BigqueryProject)
443 |   if err != nil {
444 |     t.Fatalf("unable to create BigQuery connection: %s", err)
445 |   }
446 |   dataset := client.Dataset(permanentDatasetName)
447 |   if err := dataset.Create(ctx, &bigqueryapi.DatasetMetadata{Name: permanentDatasetName}); err != nil {
448 |     t.Fatalf("Failed to create dataset %q: %v", permanentDatasetName, err)
449 |   }
450 |   defer func() {
451 |     if err := dataset.DeleteWithContents(ctx); err != nil {
452 |       t.Logf("failed to cleanup dataset %s: %v", permanentDatasetName, err)
453 |     }
454 |   }()
455 |
456 |   toolsFile := map[string]any{
457 |     "sources": map[string]any{"my-instance": sourceConfig},
458 |     "tools": map[string]any{
459 |       "my-exec-sql-tool": map[string]any{"kind": "bigquery-execute-sql", "source": "my-instance", "description": "Tool to execute sql"},
460 |       "my-sql-tool-protected": map[string]any{
461 |         "kind": "bigquery-sql",
462 |         "source": "my-instance",
463 |         "description": "Tool to query from the session",
464 |         "statement": "SELECT * FROM my_shared_temp_table",
465 |       },
466 |       "my-forecast-tool-protected": map[string]any{
467 |         "kind": "bigquery-forecast",
468 |         "source": "my-instance",
469 |         "description": "Tool to forecast from session temp table",
470 |       },
471 |       "my-analyze-contribution-tool-protected": map[string]any{
472 |         "kind": "bigquery-analyze-contribution",
473 |         "source": "my-instance",
474 |         "description": "Tool to analyze contribution from session temp table",
475 |       },
476 |     },
477 |   }
478 |
479 |   cmd, cleanup, err := tests.StartCmd(ctx, toolsFile)
480 |   if err != nil {
481 |     t.Fatalf("command initialization returned an error: %s", err)
482 |   }
483 |   defer cleanup()
484 |
485 |   waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
486 |   defer cancel()
487 |   out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
488 |   if err != nil {
489 |     t.Logf("toolbox command logs: \n%s", out)
490 |     t.Fatalf("toolbox didn't start successfully: %s", err)
491 |   }
492 |
493 |   runBigQueryWriteModeProtectedTest(t, permanentDatasetName)
494 | }
495 |
496 | // getBigQueryParamToolInfo returns statements and param for my-tool for bigquery kind
497 | func getBigQueryParamToolInfo(tableName string) (string, string, string, string, string, string, []bigqueryapi.QueryParameter) {
498 |   createStatement := fmt.Sprintf(`
499 |     CREATE TABLE IF NOT EXISTS %s (id INT64, name STRING);`, tableName)
500 |   insertStatement := fmt.Sprintf(`
501 |     INSERT INTO %s (id, name) VALUES (?, ?), (?, ?), (?, ?), (?, NULL);`, tableName)
502 |   toolStatement := fmt.Sprintf(`SELECT * FROM %s WHERE id = ? OR name = ? ORDER BY id;`, tableName)
503 |   idToolStatement := fmt.Sprintf(`SELECT * FROM %s WHERE id = ? ORDER BY id;`, tableName)
504 |   nameToolStatement := fmt.Sprintf(`SELECT * FROM %s WHERE name = ? ORDER BY id;`, tableName)
505 |   arrayToolStatememt := fmt.Sprintf(`SELECT * FROM %s WHERE id IN UNNEST(@idArray) AND name IN UNNEST(@nameArray) ORDER BY id;`, tableName)
506 |   params := []bigqueryapi.QueryParameter{
507 |     {Value: int64(1)}, {Value: "Alice"},
508 |     {Value: int64(2)}, {Value: "Jane"},
509 |     {Value: int64(3)}, {Value: "Sid"},
510 |     {Value: int64(4)},
511 |   }
512 |   return createStatement, insertStatement, toolStatement, idToolStatement, nameToolStatement, arrayToolStatememt, params
513 | }
514 |
515 | // getBigQueryAuthToolInfo returns statements and param of my-auth-tool for bigquery kind
516 | func getBigQueryAuthToolInfo(tableName string) (string, string, string, []bigqueryapi.QueryParameter) {
517 |   createStatement := fmt.Sprintf(`
518 |     CREATE TABLE IF NOT EXISTS %s (id INT64, name STRING, email STRING)`, tableName)
519 |   insertStatement := fmt.Sprintf(`
520 |     INSERT INTO %s (id, name, email) VALUES (?, ?, ?), (?, ?, ?)`, tableName)
521 |   toolStatement := fmt.Sprintf(`
522 |     SELECT name FROM %s WHERE email = ?`, tableName)
523 |   params := []bigqueryapi.QueryParameter{
524 |     {Value: int64(1)}, {Value: "Alice"}, {Value: tests.ServiceAccountEmail},
525 |     {Value: int64(2)}, {Value: "Jane"}, {Value: "[email protected]"},
526 |   }
527 |   return createStatement, insertStatement, toolStatement, params
528 | }
529 |
530 | // getBigQueryDataTypeTestInfo returns statements and params for data type tests.
531 | func getBigQueryDataTypeTestInfo(tableName string) (string, string, string, string, []bigqueryapi.QueryParameter) {
532 |   createStatement := fmt.Sprintf(`
533 |     CREATE TABLE IF NOT EXISTS %s (id INT64, int_val INT64, string_val STRING, float_val FLOAT64, bool_val BOOL);`, tableName)
534 |   insertStatement := fmt.Sprintf(`
535 |     INSERT INTO %s (id, int_val, string_val, float_val, bool_val) VALUES (?, ?, ?, ?, ?), (?, ?, ?, ?, ?), (?, ?, ?, ?, ?);`, tableName)
536 |   toolStatement := fmt.Sprintf(`SELECT * FROM %s WHERE int_val = ? AND string_val = ? AND float_val = ? AND bool_val = ?;`, tableName)
537 |   arrayToolStatement := fmt.Sprintf(`SELECT * FROM %s WHERE int_val IN UNNEST(@int_array) AND string_val IN UNNEST(@string_array) AND float_val IN UNNEST(@float_array) AND bool_val IN UNNEST(@bool_array) ORDER BY id;`, tableName)
538 |   params := []bigqueryapi.QueryParameter{
539 |     {Value: int64(1)}, {Value: int64(123)}, {Value: "hello"}, {Value: 3.14}, {Value: true},
540 |     {Value: int64(2)}, {Value: int64(-456)}, {Value: "world"}, {Value: -0.55}, {Value: false},
541 |     {Value: int64(3)}, {Value: int64(789)}, {Value: "test"}, {Value: 100.1}, {Value: true},
542 |   }
543 |   return createStatement, insertStatement, toolStatement, arrayToolStatement, params
544 | }
545 |
546 | // getBigQueryForecastToolInfo returns statements and params for the forecast tool.
547 | func getBigQueryForecastToolInfo(tableName string) (string, string, []bigqueryapi.QueryParameter) { 548 | createStatement := fmt.Sprintf(` 549 | CREATE TABLE IF NOT EXISTS %s (ts TIMESTAMP, data FLOAT64, id STRING);`, tableName) 550 | insertStatement := fmt.Sprintf(` 551 | INSERT INTO %s (ts, data, id) VALUES 552 | (?, ?, ?), (?, ?, ?), (?, ?, ?), 553 | (?, ?, ?), (?, ?, ?), (?, ?, ?);`, tableName) 554 | params := []bigqueryapi.QueryParameter{ 555 | {Value: "2025-01-01T00:00:00Z"}, {Value: 10.0}, {Value: "a"}, 556 | {Value: "2025-01-01T01:00:00Z"}, {Value: 11.0}, {Value: "a"}, 557 | {Value: "2025-01-01T02:00:00Z"}, {Value: 12.0}, {Value: "a"}, 558 | {Value: "2025-01-01T00:00:00Z"}, {Value: 20.0}, {Value: "b"}, 559 | {Value: "2025-01-01T01:00:00Z"}, {Value: 21.0}, {Value: "b"}, 560 | {Value: "2025-01-01T02:00:00Z"}, {Value: 22.0}, {Value: "b"}, 561 | } 562 | return createStatement, insertStatement, params 563 | } 564 | 565 | // getBigQueryAnalyzeContributionToolInfo returns statements and params for the analyze-contribution tool. 566 | func getBigQueryAnalyzeContributionToolInfo(tableName string) (string, string, []bigqueryapi.QueryParameter) { 567 | createStatement := fmt.Sprintf(` 568 | CREATE TABLE IF NOT EXISTS %s (dim1 STRING, dim2 STRING, is_test BOOL, metric FLOAT64);`, tableName) 569 | insertStatement := fmt.Sprintf(` 570 | INSERT INTO %s (dim1, dim2, is_test, metric) VALUES 571 | (?, ?, ?, ?), (?, ?, ?, ?), (?, ?, ?, ?), (?, ?, ?, ?);`, tableName) 572 | params := []bigqueryapi.QueryParameter{ 573 | {Value: "a"}, {Value: "x"}, {Value: true}, {Value: 100.0}, 574 | {Value: "a"}, {Value: "x"}, {Value: false}, {Value: 110.0}, 575 | {Value: "a"}, {Value: "y"}, {Value: true}, {Value: 120.0}, 576 | {Value: "a"}, {Value: "y"}, {Value: false}, {Value: 100.0}, 577 | {Value: "b"}, {Value: "x"}, {Value: true}, {Value: 40.0}, 578 | {Value: "b"}, {Value: "x"}, {Value: false}, {Value: 100.0}, 579 | {Value: "b"}, {Value: "y"}, {Value: true}, {Value: 60.0}, 580 | {Value: "b"}, {Value: "y"}, {Value: false}, {Value: 60.0}, 581 | } 582 | return createStatement, insertStatement, params 583 | } 584 | 585 | // getBigQueryTmplToolStatement returns statements for template parameter test cases for bigquery kind 586 | func getBigQueryTmplToolStatement() (string, string) { 587 | tmplSelectCombined := "SELECT * FROM {{.tableName}} WHERE id = ? ORDER BY id" 588 | tmplSelectFilterCombined := "SELECT * FROM {{.tableName}} WHERE {{.columnFilter}} = ? 
ORDER BY id" 589 | return tmplSelectCombined, tmplSelectFilterCombined 590 | } 591 | 592 | func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.Client, createStatement, insertStatement, datasetName string, tableName string, params []bigqueryapi.QueryParameter) func(*testing.T) { 593 | // Create dataset 594 | dataset := client.Dataset(datasetName) 595 | _, err := dataset.Metadata(ctx) 596 | 597 | if err != nil { 598 | apiErr, ok := err.(*googleapi.Error) 599 | if !ok || apiErr.Code != 404 { 600 | t.Fatalf("Failed to check dataset %q existence: %v", datasetName, err) 601 | } 602 | metadataToCreate := &bigqueryapi.DatasetMetadata{Name: datasetName} 603 | if err := dataset.Create(ctx, metadataToCreate); err != nil { 604 | t.Fatalf("Failed to create dataset %q: %v", datasetName, err) 605 | } 606 | } 607 | 608 | // Create table 609 | createJob, err := client.Query(createStatement).Run(ctx) 610 | 611 | if err != nil { 612 | t.Fatalf("Failed to start create table job for %s: %v", tableName, err) 613 | } 614 | createStatus, err := createJob.Wait(ctx) 615 | if err != nil { 616 | t.Fatalf("Failed to wait for create table job for %s: %v", tableName, err) 617 | } 618 | if err := createStatus.Err(); err != nil { 619 | t.Fatalf("Create table job for %s failed: %v", tableName, err) 620 | } 621 | 622 | if len(params) > 0 { 623 | // Insert test data 624 | insertQuery := client.Query(insertStatement) 625 | insertQuery.Parameters = params 626 | insertJob, err := insertQuery.Run(ctx) 627 | if err != nil { 628 | t.Fatalf("Failed to start insert job for %s: %v", tableName, err) 629 | } 630 | insertStatus, err := insertJob.Wait(ctx) 631 | if err != nil { 632 | t.Fatalf("Failed to wait for insert job for %s: %v", tableName, err) 633 | } 634 | if err := insertStatus.Err(); err != nil { 635 | t.Fatalf("Insert job for %s failed: %v", tableName, err) 636 | } 637 | } 638 | 639 | return func(t *testing.T) { 640 | // tear down table 641 | dropSQL := fmt.Sprintf("drop table %s", tableName) 642 | dropJob, err := client.Query(dropSQL).Run(ctx) 643 | if err != nil { 644 | t.Errorf("Failed to start drop table job for %s: %v", tableName, err) 645 | return 646 | } 647 | dropStatus, err := dropJob.Wait(ctx) 648 | if err != nil { 649 | t.Errorf("Failed to wait for drop table job for %s: %v", tableName, err) 650 | return 651 | } 652 | if err := dropStatus.Err(); err != nil { 653 | t.Errorf("Error dropping table %s: %v", tableName, err) 654 | } 655 | 656 | // tear down dataset 657 | datasetToTeardown := client.Dataset(datasetName) 658 | tablesIterator := datasetToTeardown.Tables(ctx) 659 | _, err = tablesIterator.Next() 660 | 661 | if err == iterator.Done { 662 | if err := datasetToTeardown.Delete(ctx); err != nil { 663 | t.Errorf("Failed to delete dataset %s: %v", datasetName, err) 664 | } 665 | } else if err != nil { 666 | t.Errorf("Failed to list tables in dataset %s to check emptiness: %v.", datasetName, err) 667 | } 668 | } 669 | } 670 | 671 | func addBigQueryPrebuiltToolsConfig(t *testing.T, config map[string]any) map[string]any { 672 | tools, ok := config["tools"].(map[string]any) 673 | if !ok { 674 | t.Fatalf("unable to get tools from config") 675 | } 676 | tools["my-exec-sql-tool"] = map[string]any{ 677 | "kind": "bigquery-execute-sql", 678 | "source": "my-instance", 679 | "description": "Tool to execute sql", 680 | } 681 | tools["my-auth-exec-sql-tool"] = map[string]any{ 682 | "kind": "bigquery-execute-sql", 683 | "source": "my-instance", 684 | "description": "Tool to execute sql", 685 | 
"authRequired": []string{ 686 | "my-google-auth", 687 | }, 688 | } 689 | tools["my-client-auth-exec-sql-tool"] = map[string]any{ 690 | "kind": "bigquery-execute-sql", 691 | "source": "my-client-auth-source", 692 | "description": "Tool to execute sql", 693 | } 694 | tools["my-forecast-tool"] = map[string]any{ 695 | "kind": "bigquery-forecast", 696 | "source": "my-instance", 697 | "description": "Tool to forecast time series data.", 698 | } 699 | tools["my-auth-forecast-tool"] = map[string]any{ 700 | "kind": "bigquery-forecast", 701 | "source": "my-instance", 702 | "description": "Tool to forecast time series data with auth.", 703 | "authRequired": []string{ 704 | "my-google-auth", 705 | }, 706 | } 707 | tools["my-client-auth-forecast-tool"] = map[string]any{ 708 | "kind": "bigquery-forecast", 709 | "source": "my-client-auth-source", 710 | "description": "Tool to forecast time series data with auth.", 711 | } 712 | tools["my-analyze-contribution-tool"] = map[string]any{ 713 | "kind": "bigquery-analyze-contribution", 714 | "source": "my-instance", 715 | "description": "Tool to analyze contribution.", 716 | } 717 | tools["my-auth-analyze-contribution-tool"] = map[string]any{ 718 | "kind": "bigquery-analyze-contribution", 719 | "source": "my-instance", 720 | "description": "Tool to analyze contribution with auth.", 721 | "authRequired": []string{ 722 | "my-google-auth", 723 | }, 724 | } 725 | tools["my-client-auth-analyze-contribution-tool"] = map[string]any{ 726 | "kind": "bigquery-analyze-contribution", 727 | "source": "my-client-auth-source", 728 | "description": "Tool to analyze contribution with auth.", 729 | } 730 | tools["my-list-dataset-ids-tool"] = map[string]any{ 731 | "kind": "bigquery-list-dataset-ids", 732 | "source": "my-instance", 733 | "description": "Tool to list dataset", 734 | } 735 | tools["my-auth-list-dataset-ids-tool"] = map[string]any{ 736 | "kind": "bigquery-list-dataset-ids", 737 | "source": "my-instance", 738 | "description": "Tool to list dataset", 739 | "authRequired": []string{ 740 | "my-google-auth", 741 | }, 742 | } 743 | tools["my-client-auth-list-dataset-ids-tool"] = map[string]any{ 744 | "kind": "bigquery-list-dataset-ids", 745 | "source": "my-client-auth-source", 746 | "description": "Tool to list dataset", 747 | } 748 | tools["my-get-dataset-info-tool"] = map[string]any{ 749 | "kind": "bigquery-get-dataset-info", 750 | "source": "my-instance", 751 | "description": "Tool to show dataset metadata", 752 | } 753 | tools["my-auth-get-dataset-info-tool"] = map[string]any{ 754 | "kind": "bigquery-get-dataset-info", 755 | "source": "my-instance", 756 | "description": "Tool to show dataset metadata", 757 | "authRequired": []string{ 758 | "my-google-auth", 759 | }, 760 | } 761 | tools["my-client-auth-get-dataset-info-tool"] = map[string]any{ 762 | "kind": "bigquery-get-dataset-info", 763 | "source": "my-client-auth-source", 764 | "description": "Tool to show dataset metadata", 765 | } 766 | tools["my-list-table-ids-tool"] = map[string]any{ 767 | "kind": "bigquery-list-table-ids", 768 | "source": "my-instance", 769 | "description": "Tool to list table within a dataset", 770 | } 771 | tools["my-auth-list-table-ids-tool"] = map[string]any{ 772 | "kind": "bigquery-list-table-ids", 773 | "source": "my-instance", 774 | "description": "Tool to list table within a dataset", 775 | "authRequired": []string{ 776 | "my-google-auth", 777 | }, 778 | } 779 | tools["my-client-auth-list-table-ids-tool"] = map[string]any{ 780 | "kind": "bigquery-list-table-ids", 781 | "source": 
"my-client-auth-source", 782 | "description": "Tool to list table within a dataset", 783 | } 784 | tools["my-get-table-info-tool"] = map[string]any{ 785 | "kind": "bigquery-get-table-info", 786 | "source": "my-instance", 787 | "description": "Tool to show dataset metadata", 788 | } 789 | tools["my-auth-get-table-info-tool"] = map[string]any{ 790 | "kind": "bigquery-get-table-info", 791 | "source": "my-instance", 792 | "description": "Tool to show dataset metadata", 793 | "authRequired": []string{ 794 | "my-google-auth", 795 | }, 796 | } 797 | tools["my-client-auth-get-table-info-tool"] = map[string]any{ 798 | "kind": "bigquery-get-table-info", 799 | "source": "my-client-auth-source", 800 | "description": "Tool to show dataset metadata", 801 | } 802 | tools["my-conversational-analytics-tool"] = map[string]any{ 803 | "kind": "bigquery-conversational-analytics", 804 | "source": "my-instance", 805 | "description": "Tool to ask BigQuery conversational analytics", 806 | } 807 | tools["my-auth-conversational-analytics-tool"] = map[string]any{ 808 | "kind": "bigquery-conversational-analytics", 809 | "source": "my-instance", 810 | "description": "Tool to ask BigQuery conversational analytics", 811 | "authRequired": []string{ 812 | "my-google-auth", 813 | }, 814 | } 815 | tools["my-client-auth-conversational-analytics-tool"] = map[string]any{ 816 | "kind": "bigquery-conversational-analytics", 817 | "source": "my-client-auth-source", 818 | "description": "Tool to ask BigQuery conversational analytics", 819 | } 820 | tools["my-search-catalog-tool"] = map[string]any{ 821 | "kind": "bigquery-search-catalog", 822 | "source": "my-instance", 823 | "description": "Tool to search the BiqQuery catalog", 824 | } 825 | tools["my-auth-search-catalog-tool"] = map[string]any{ 826 | "kind": "bigquery-search-catalog", 827 | "source": "my-instance", 828 | "description": "Tool to search the BiqQuery catalog", 829 | "authRequired": []string{ 830 | "my-google-auth", 831 | }, 832 | } 833 | tools["my-client-auth-search-catalog-tool"] = map[string]any{ 834 | "kind": "bigquery-search-catalog", 835 | "source": "my-client-auth-source", 836 | "description": "Tool to search the BiqQuery catalog", 837 | } 838 | config["tools"] = tools 839 | return config 840 | } 841 | 842 | func addClientAuthSourceConfig(t *testing.T, config map[string]any) map[string]any { 843 | sources, ok := config["sources"].(map[string]any) 844 | if !ok { 845 | t.Fatalf("unable to get sources from config") 846 | } 847 | sources["my-client-auth-source"] = map[string]any{ 848 | "kind": BigquerySourceKind, 849 | "project": BigqueryProject, 850 | "useClientOAuth": true, 851 | } 852 | config["sources"] = sources 853 | return config 854 | } 855 | 856 | func addBigQuerySqlToolConfig(t *testing.T, config map[string]any, toolStatement, arrayToolStatement string) map[string]any { 857 | tools, ok := config["tools"].(map[string]any) 858 | if !ok { 859 | t.Fatalf("unable to get tools from config") 860 | } 861 | tools["my-scalar-datatype-tool"] = map[string]any{ 862 | "kind": "bigquery-sql", 863 | "source": "my-instance", 864 | "description": "Tool to test various scalar data types.", 865 | "statement": toolStatement, 866 | "parameters": []any{ 867 | map[string]any{"name": "int_val", "type": "integer", "description": "an integer value"}, 868 | map[string]any{"name": "string_val", "type": "string", "description": "a string value"}, 869 | map[string]any{"name": "float_val", "type": "float", "description": "a float value"}, 870 | map[string]any{"name": "bool_val", "type": 
"boolean", "description": "a boolean value"}, 871 | }, 872 | } 873 | tools["my-array-datatype-tool"] = map[string]any{ 874 | "kind": "bigquery-sql", 875 | "source": "my-instance", 876 | "description": "Tool to test various array data types.", 877 | "statement": arrayToolStatement, 878 | "parameters": []any{ 879 | map[string]any{"name": "int_array", "type": "array", "description": "an array of integer values", "items": map[string]any{"name": "item", "type": "integer", "description": "desc"}}, 880 | map[string]any{"name": "string_array", "type": "array", "description": "an array of string values", "items": map[string]any{"name": "item", "type": "string", "description": "desc"}}, 881 | map[string]any{"name": "float_array", "type": "array", "description": "an array of float values", "items": map[string]any{"name": "item", "type": "float", "description": "desc"}}, 882 | map[string]any{"name": "bool_array", "type": "array", "description": "an array of boolean values", "items": map[string]any{"name": "item", "type": "boolean", "description": "desc"}}, 883 | }, 884 | } 885 | tools["my-client-auth-tool"] = map[string]any{ 886 | "kind": "bigquery-sql", 887 | "source": "my-client-auth-source", 888 | "description": "Tool to test client authorization.", 889 | "statement": "SELECT 1", 890 | } 891 | config["tools"] = tools 892 | return config 893 | } 894 | 895 | func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamWant, tableNameParam, ddlWant string) { 896 | // Get ID token 897 | idToken, err := tests.GetGoogleIdToken(tests.ClientId) 898 | if err != nil { 899 | t.Fatalf("error getting Google ID token: %s", err) 900 | } 901 | 902 | // Get access token 903 | accessToken, err := sources.GetIAMAccessToken(t.Context()) 904 | if err != nil { 905 | t.Fatalf("error getting access token from ADC: %s", err) 906 | } 907 | accessToken = "Bearer " + accessToken 908 | 909 | // Test tool invoke endpoint 910 | invokeTcs := []struct { 911 | name string 912 | api string 913 | requestHeader map[string]string 914 | requestBody io.Reader 915 | want string 916 | isErr bool 917 | }{ 918 | { 919 | name: "invoke my-exec-sql-tool without body", 920 | api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", 921 | requestHeader: map[string]string{}, 922 | requestBody: bytes.NewBuffer([]byte(`{}`)), 923 | isErr: true, 924 | }, 925 | { 926 | name: "invoke my-exec-sql-tool", 927 | api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", 928 | requestHeader: map[string]string{}, 929 | requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1"}`)), 930 | want: select1Want, 931 | isErr: false, 932 | }, 933 | { 934 | name: "invoke my-exec-sql-tool create table", 935 | api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", 936 | requestHeader: map[string]string{}, 937 | requestBody: bytes.NewBuffer([]byte(`{"sql":"CREATE TABLE t (id SERIAL PRIMARY KEY, name TEXT)"}`)), 938 | want: ddlWant, 939 | isErr: true, 940 | }, 941 | { 942 | name: "invoke my-exec-sql-tool with data present in table", 943 | api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", 944 | requestHeader: map[string]string{}, 945 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"sql\":\"SELECT * FROM %s WHERE id = 3 OR name = 'Alice' ORDER BY id\"}", tableNameParam))), 946 | want: invokeParamWant, 947 | isErr: false, 948 | }, 949 | { 950 | name: "invoke my-exec-sql-tool with no matching rows", 951 | api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", 952 | requestHeader: map[string]string{}, 953 | requestBody: 
bytes.NewBuffer([]byte(fmt.Sprintf("{\"sql\":\"SELECT * FROM %s WHERE id = 999\"}", tableNameParam))), 954 | want: `"The query returned 0 rows."`, 955 | isErr: false, 956 | }, 957 | { 958 | name: "invoke my-exec-sql-tool drop table", 959 | api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", 960 | requestHeader: map[string]string{}, 961 | requestBody: bytes.NewBuffer([]byte(`{"sql":"DROP TABLE t"}`)), 962 | want: ddlWant, 963 | isErr: true, 964 | }, 965 | { 966 | name: "invoke my-exec-sql-tool insert entry", 967 | api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", 968 | requestHeader: map[string]string{}, 969 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"sql\":\"INSERT INTO %s (id, name) VALUES (4, 'test_name')\"}", tableNameParam))), 970 | want: ddlWant, 971 | isErr: false, 972 | }, 973 | { 974 | name: "invoke my-exec-sql-tool without body", 975 | api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", 976 | requestHeader: map[string]string{}, 977 | requestBody: bytes.NewBuffer([]byte(`{}`)), 978 | isErr: true, 979 | }, 980 | { 981 | name: "Invoke my-auth-exec-sql-tool with auth token", 982 | api: "http://127.0.0.1:5000/api/tool/my-auth-exec-sql-tool/invoke", 983 | requestHeader: map[string]string{"my-google-auth_token": idToken}, 984 | requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1"}`)), 985 | isErr: false, 986 | want: select1Want, 987 | }, 988 | { 989 | name: "Invoke my-auth-exec-sql-tool with invalid auth token", 990 | api: "http://127.0.0.1:5000/api/tool/my-auth-exec-sql-tool/invoke", 991 | requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"}, 992 | requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1"}`)), 993 | isErr: true, 994 | }, 995 | { 996 | name: "Invoke my-auth-exec-sql-tool without auth token", 997 | api: "http://127.0.0.1:5000/api/tool/my-auth-exec-sql-tool/invoke", 998 | requestHeader: map[string]string{}, 999 | requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1"}`)), 1000 | isErr: true, 1001 | }, 1002 | { 1003 | name: "Invoke my-client-auth-exec-sql-tool with auth token", 1004 | api: "http://127.0.0.1:5000/api/tool/my-client-auth-exec-sql-tool/invoke", 1005 | requestHeader: map[string]string{"Authorization": accessToken}, 1006 | requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1"}`)), 1007 | want: "[{\"f0_\":1}]", 1008 | isErr: false, 1009 | }, 1010 | { 1011 | name: "Invoke my-client-auth-exec-sql-tool without auth token", 1012 | api: "http://127.0.0.1:5000/api/tool/my-client-auth-exec-sql-tool/invoke", 1013 | requestHeader: map[string]string{}, 1014 | requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1"}`)), 1015 | isErr: true, 1016 | }, 1017 | { 1018 | 1019 | name: "Invoke my-client-auth-exec-sql-tool with invalid auth token", 1020 | api: "http://127.0.0.1:5000/api/tool/my-client-auth-exec-sql-tool/invoke", 1021 | requestHeader: map[string]string{"Authorization": "Bearer invalid-token"}, 1022 | requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1"}`)), 1023 | isErr: true, 1024 | }, 1025 | } 1026 | for _, tc := range invokeTcs { 1027 | t.Run(tc.name, func(t *testing.T) { 1028 | // Send Tool invocation request 1029 | req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody) 1030 | if err != nil { 1031 | t.Fatalf("unable to create request: %s", err) 1032 | } 1033 | req.Header.Add("Content-type", "application/json") 1034 | for k, v := range tc.requestHeader { 1035 | req.Header.Add(k, v) 1036 | } 1037 | resp, err := http.DefaultClient.Do(req) 1038 | if err != nil { 1039 | t.Fatalf("unable 
to send request: %s", err)
			}
			defer resp.Body.Close()

			if resp.StatusCode != http.StatusOK {
				if tc.isErr {
					return
				}
				bodyBytes, _ := io.ReadAll(resp.Body)
				t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
			}

			// Check response body
			var body map[string]interface{}
			err = json.NewDecoder(resp.Body).Decode(&body)
			if err != nil {
				t.Fatalf("error parsing response body")
			}

			got, ok := body["result"].(string)
			if !ok {
				t.Fatalf("unable to find result in response body")
			}

			if got != tc.want {
				t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
			}
		})
	}
}

// runInvokeRequest sends a POST request to the given API endpoint and returns the response and parsed JSON body.
func runInvokeRequest(t *testing.T, api, body string, headers map[string]string) (*http.Response, map[string]interface{}) {
	t.Helper()
	req, err := http.NewRequest(http.MethodPost, api, bytes.NewBufferString(body))
	if err != nil {
		t.Fatalf("failed to create request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")
	for k, v := range headers {
		req.Header.Add(k, v)
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("failed to send request: %v", err)
	}

	var result map[string]interface{}
	// Read the full body once, then restore resp.Body so callers can read it again (e.g. for logging on failure).
	bodyBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("failed to read response body: %v", err)
	}
	resp.Body.Close()                                    // Close original body
	resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Replace with a new reader

	if err := json.Unmarshal(bodyBytes, &result); err != nil {
		t.Logf("Failed to decode response body: %s", string(bodyBytes))
		t.Fatalf("failed to decode response: %v", err)
	}
	return resp, result
}

func runBigQueryWriteModeAllowedTest(t *testing.T, datasetName string) {
	t.Run("CREATE TABLE should succeed", func(t *testing.T) {
		sql := fmt.Sprintf("CREATE TABLE %s.new_table (x INT64)", datasetName)
		body := fmt.Sprintf(`{"sql": "%s"}`, sql)
		resp, result := runInvokeRequest(t, "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", body, nil)
		defer resp.Body.Close()

		if resp.StatusCode != http.StatusOK {
			bodyBytes, _ := io.ReadAll(resp.Body)
			t.Fatalf("unexpected status code: got %d, want %d.
Body: %s", resp.StatusCode, http.StatusOK, string(bodyBytes)) 1113 | } 1114 | 1115 | resStr, ok := result["result"].(string) 1116 | if !ok { 1117 | t.Fatalf("expected 'result' field in response, got %v", result) 1118 | } 1119 | if resStr != `"Query executed successfully and returned no content."` { 1120 | t.Errorf("unexpected result: got %q, want %q", resStr, `"Query executed successfully and returned no content."`) 1121 | } 1122 | }) 1123 | } 1124 | 1125 | func runBigQueryWriteModeBlockedTest(t *testing.T, tableNameParam, datasetName string) { 1126 | testCases := []struct { 1127 | name string 1128 | sql string 1129 | wantStatusCode int 1130 | wantInError string 1131 | wantResult string 1132 | }{ 1133 | {"SELECT statement should succeed", fmt.Sprintf("SELECT * FROM %s WHERE id = 1", tableNameParam), http.StatusOK, "", `[{"id":1,"name":"Alice"}]`}, 1134 | {"INSERT statement should fail", fmt.Sprintf("INSERT INTO %s (id, name) VALUES (10, 'test')", tableNameParam), http.StatusBadRequest, "write mode is 'blocked', only SELECT statements are allowed", ""}, 1135 | {"CREATE TABLE statement should fail", fmt.Sprintf("CREATE TABLE %s.new_table (x INT64)", datasetName), http.StatusBadRequest, "write mode is 'blocked', only SELECT statements are allowed", ""}, 1136 | } 1137 | 1138 | for _, tc := range testCases { 1139 | t.Run(tc.name, func(t *testing.T) { 1140 | body := fmt.Sprintf(`{"sql": "%s"}`, tc.sql) 1141 | resp, result := runInvokeRequest(t, "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", body, nil) 1142 | defer resp.Body.Close() 1143 | 1144 | if resp.StatusCode != tc.wantStatusCode { 1145 | bodyBytes, _ := io.ReadAll(resp.Body) 1146 | t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes)) 1147 | } 1148 | 1149 | if tc.wantInError != "" { 1150 | errStr, ok := result["error"].(string) 1151 | if !ok { 1152 | t.Fatalf("expected 'error' field in response, got %v", result) 1153 | } 1154 | if !strings.Contains(errStr, tc.wantInError) { 1155 | t.Fatalf("expected error message to contain %q, but got %q", tc.wantInError, errStr) 1156 | } 1157 | } 1158 | if tc.wantResult != "" { 1159 | resStr, ok := result["result"].(string) 1160 | if !ok { 1161 | t.Fatalf("expected 'result' field in response, got %v", result) 1162 | } 1163 | if resStr != tc.wantResult { 1164 | t.Fatalf("unexpected result: got %q, want %q", resStr, tc.wantResult) 1165 | } 1166 | } 1167 | }) 1168 | } 1169 | } 1170 | 1171 | func runBigQueryWriteModeProtectedTest(t *testing.T, permanentDatasetName string) { 1172 | testCases := []struct { 1173 | name string 1174 | toolName string 1175 | requestBody string 1176 | wantStatusCode int 1177 | wantInError string 1178 | wantResult string 1179 | }{ 1180 | { 1181 | name: "CREATE TABLE to permanent dataset should fail", 1182 | toolName: "my-exec-sql-tool", 1183 | requestBody: fmt.Sprintf(`{"sql": "CREATE TABLE %s.new_table (x INT64)"}`, permanentDatasetName), 1184 | wantStatusCode: http.StatusBadRequest, 1185 | wantInError: "protected write mode only supports SELECT statements, or write operations in the anonymous dataset", 1186 | wantResult: "", 1187 | }, 1188 | { 1189 | name: "CREATE TEMP TABLE should succeed", 1190 | toolName: "my-exec-sql-tool", 1191 | requestBody: `{"sql": "CREATE TEMP TABLE my_shared_temp_table (x INT64)"}`, 1192 | wantStatusCode: http.StatusOK, 1193 | wantInError: "", 1194 | wantResult: `"Query executed successfully and returned no content."`, 1195 | }, 1196 | { 1197 | name: "INSERT into TEMP TABLE should 
succeed", 1198 | toolName: "my-exec-sql-tool", 1199 | requestBody: `{"sql": "INSERT INTO my_shared_temp_table (x) VALUES (42)"}`, 1200 | wantStatusCode: http.StatusOK, 1201 | wantInError: "", 1202 | wantResult: `"Query executed successfully and returned no content."`, 1203 | }, 1204 | { 1205 | name: "SELECT from TEMP TABLE with exec-sql should succeed", 1206 | toolName: "my-exec-sql-tool", 1207 | requestBody: `{"sql": "SELECT * FROM my_shared_temp_table"}`, 1208 | wantStatusCode: http.StatusOK, 1209 | wantInError: "", 1210 | wantResult: `[{"x":42}]`, 1211 | }, 1212 | { 1213 | name: "SELECT from TEMP TABLE with sql-tool should succeed", 1214 | toolName: "my-sql-tool-protected", 1215 | requestBody: `{}`, 1216 | wantStatusCode: http.StatusOK, 1217 | wantInError: "", 1218 | wantResult: `[{"x":42}]`, 1219 | }, 1220 | { 1221 | name: "CREATE TEMP TABLE for forecast should succeed", 1222 | toolName: "my-exec-sql-tool", 1223 | requestBody: `{"sql": "CREATE TEMP TABLE forecast_temp_table (ts TIMESTAMP, data FLOAT64) AS SELECT TIMESTAMP('2025-01-01T00:00:00Z') AS ts, 10.0 AS data UNION ALL SELECT TIMESTAMP('2025-01-01T01:00:00Z'), 11.0 UNION ALL SELECT TIMESTAMP('2025-01-01T02:00:00Z'), 12.0 UNION ALL SELECT TIMESTAMP('2025-01-01T03:00:00Z'), 13.0"}`, 1224 | wantStatusCode: http.StatusOK, 1225 | wantInError: "", 1226 | wantResult: `"Query executed successfully and returned no content."`, 1227 | }, 1228 | { 1229 | name: "Forecast from TEMP TABLE should succeed", 1230 | toolName: "my-forecast-tool-protected", 1231 | requestBody: `{"history_data": "SELECT * FROM forecast_temp_table", "timestamp_col": "ts", "data_col": "data", "horizon": 1}`, 1232 | wantStatusCode: http.StatusOK, 1233 | wantInError: "", 1234 | wantResult: `"forecast_timestamp"`, 1235 | }, 1236 | { 1237 | name: "CREATE TEMP TABLE for contribution analysis should succeed", 1238 | toolName: "my-exec-sql-tool", 1239 | requestBody: `{"sql": "CREATE TEMP TABLE contribution_temp_table (dim1 STRING, is_test BOOL, metric FLOAT64) AS SELECT 'a' as dim1, true as is_test, 100.0 as metric UNION ALL SELECT 'b', false, 120.0"}`, 1240 | wantStatusCode: http.StatusOK, 1241 | wantInError: "", 1242 | wantResult: `"Query executed successfully and returned no content."`, 1243 | }, 1244 | { 1245 | name: "Analyze contribution from TEMP TABLE should succeed", 1246 | toolName: "my-analyze-contribution-tool-protected", 1247 | requestBody: `{"input_data": "SELECT * FROM contribution_temp_table", "contribution_metric": "SUM(metric)", "is_test_col": "is_test", "dimension_id_cols": ["dim1"]}`, 1248 | wantStatusCode: http.StatusOK, 1249 | wantInError: "", 1250 | wantResult: `"relative_difference"`, 1251 | }, 1252 | } 1253 | for _, tc := range testCases { 1254 | t.Run(tc.name, func(t *testing.T) { 1255 | api := fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/invoke", tc.toolName) 1256 | resp, result := runInvokeRequest(t, api, tc.requestBody, nil) 1257 | defer resp.Body.Close() 1258 | 1259 | if resp.StatusCode != tc.wantStatusCode { 1260 | bodyBytes, _ := io.ReadAll(resp.Body) 1261 | t.Fatalf("unexpected status code: got %d, want %d. 
Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes)) 1262 | } 1263 | 1264 | if tc.wantInError != "" { 1265 | errStr, ok := result["error"].(string) 1266 | if !ok { 1267 | t.Fatalf("expected 'error' field in response, got %v", result) 1268 | } 1269 | if !strings.Contains(errStr, tc.wantInError) { 1270 | t.Fatalf("expected error message to contain %q, but got %q", tc.wantInError, errStr) 1271 | } 1272 | } 1273 | 1274 | if tc.wantResult != "" { 1275 | resStr, ok := result["result"].(string) 1276 | if !ok { 1277 | t.Fatalf("expected 'result' field in response, got %v", result) 1278 | } 1279 | if !strings.Contains(resStr, tc.wantResult) { 1280 | t.Fatalf("expected %q to contain %q, but it did not", resStr, tc.wantResult) 1281 | } 1282 | } 1283 | }) 1284 | } 1285 | } 1286 | 1287 | func runBigQueryExecuteSqlToolInvokeDryRunTest(t *testing.T, datasetName string) { 1288 | // Get ID token 1289 | idToken, err := tests.GetGoogleIdToken(tests.ClientId) 1290 | if err != nil { 1291 | t.Fatalf("error getting Google ID token: %s", err) 1292 | } 1293 | 1294 | newTableName := fmt.Sprintf("%s.new_dry_run_table_%s", datasetName, strings.ReplaceAll(uuid.New().String(), "-", "")) 1295 | 1296 | // Test tool invoke endpoint 1297 | invokeTcs := []struct { 1298 | name string 1299 | api string 1300 | requestHeader map[string]string 1301 | requestBody io.Reader 1302 | want string 1303 | isErr bool 1304 | }{ 1305 | { 1306 | name: "invoke my-exec-sql-tool with dryRun", 1307 | api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", 1308 | requestHeader: map[string]string{}, 1309 | requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1", "dry_run": true}`)), 1310 | want: `\"statementType\": \"SELECT\"`, 1311 | isErr: false, 1312 | }, 1313 | { 1314 | name: "invoke my-exec-sql-tool with dryRun create table", 1315 | api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", 1316 | requestHeader: map[string]string{}, 1317 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"sql":"CREATE TABLE %s (id INT64, name STRING)", "dry_run": true}`, newTableName))), 1318 | want: `\"statementType\": \"CREATE_TABLE\"`, 1319 | isErr: false, 1320 | }, 1321 | { 1322 | name: "invoke my-exec-sql-tool with dryRun execute immediate", 1323 | api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", 1324 | requestHeader: map[string]string{}, 1325 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"sql":"EXECUTE IMMEDIATE \"CREATE TABLE %s (id INT64, name STRING)\"", "dry_run": true}`, newTableName))), 1326 | want: `\"statementType\": \"SCRIPT\"`, 1327 | isErr: false, 1328 | }, 1329 | { 1330 | name: "Invoke my-auth-exec-sql-tool with dryRun and auth token", 1331 | api: "http://127.0.0.1:5000/api/tool/my-auth-exec-sql-tool/invoke", 1332 | requestHeader: map[string]string{"my-google-auth_token": idToken}, 1333 | requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1", "dry_run": true}`)), 1334 | isErr: false, 1335 | want: `\"statementType\": \"SELECT\"`, 1336 | }, 1337 | { 1338 | name: "Invoke my-auth-exec-sql-tool with dryRun and invalid auth token", 1339 | api: "http://127.0.0.1:5000/api/tool/my-auth-exec-sql-tool/invoke", 1340 | requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"}, 1341 | requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1","dry_run": true}`)), 1342 | isErr: true, 1343 | }, 1344 | { 1345 | name: "Invoke my-auth-exec-sql-tool with dryRun and without auth token", 1346 | api: "http://127.0.0.1:5000/api/tool/my-auth-exec-sql-tool/invoke", 1347 | requestHeader: map[string]string{}, 
1348 | requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1", "dry_run": true}`)), 1349 | isErr: true, 1350 | }, 1351 | } 1352 | for _, tc := range invokeTcs { 1353 | t.Run(tc.name, func(t *testing.T) { 1354 | // Send Tool invocation request 1355 | req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody) 1356 | if err != nil { 1357 | t.Fatalf("unable to create request: %s", err) 1358 | } 1359 | req.Header.Add("Content-type", "application/json") 1360 | for k, v := range tc.requestHeader { 1361 | req.Header.Add(k, v) 1362 | } 1363 | resp, err := http.DefaultClient.Do(req) 1364 | if err != nil { 1365 | t.Fatalf("unable to send request: %s", err) 1366 | } 1367 | defer resp.Body.Close() 1368 | 1369 | if resp.StatusCode != http.StatusOK { 1370 | if tc.isErr { 1371 | return 1372 | } 1373 | bodyBytes, _ := io.ReadAll(resp.Body) 1374 | t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) 1375 | } 1376 | 1377 | // Check response body 1378 | var body map[string]interface{} 1379 | err = json.NewDecoder(resp.Body).Decode(&body) 1380 | if err != nil { 1381 | t.Fatalf("error parsing response body") 1382 | } 1383 | 1384 | got, ok := body["result"].(string) 1385 | if !ok { 1386 | t.Fatalf("unable to find result in response body") 1387 | } 1388 | 1389 | if !strings.Contains(got, tc.want) { 1390 | t.Fatalf("expected %q to contain %q, but it did not", got, tc.want) 1391 | } 1392 | }) 1393 | } 1394 | } 1395 | 1396 | func runBigQueryForecastToolInvokeTest(t *testing.T, tableName string) { 1397 | idToken, err := tests.GetGoogleIdToken(tests.ClientId) 1398 | if err != nil { 1399 | t.Fatalf("error getting Google ID token: %s", err) 1400 | } 1401 | 1402 | // Get access token 1403 | accessToken, err := sources.GetIAMAccessToken(t.Context()) 1404 | if err != nil { 1405 | t.Fatalf("error getting access token from ADC: %s", err) 1406 | } 1407 | accessToken = "Bearer " + accessToken 1408 | 1409 | historyDataTable := strings.ReplaceAll(tableName, "`", "") 1410 | historyDataQuery := fmt.Sprintf("SELECT ts, data, id FROM %s", tableName) 1411 | 1412 | invokeTcs := []struct { 1413 | name string 1414 | api string 1415 | requestHeader map[string]string 1416 | requestBody io.Reader 1417 | want string 1418 | isErr bool 1419 | }{ 1420 | { 1421 | name: "invoke my-forecast-tool without required params", 1422 | api: "http://127.0.0.1:5000/api/tool/my-forecast-tool/invoke", 1423 | requestHeader: map[string]string{}, 1424 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"history_data": "%s"}`, historyDataTable))), 1425 | isErr: true, 1426 | }, 1427 | { 1428 | name: "invoke my-forecast-tool with table", 1429 | api: "http://127.0.0.1:5000/api/tool/my-forecast-tool/invoke", 1430 | requestHeader: map[string]string{}, 1431 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"history_data": "%s", "timestamp_col": "ts", "data_col": "data"}`, historyDataTable))), 1432 | want: `"forecast_timestamp"`, 1433 | isErr: false, 1434 | }, 1435 | { 1436 | name: "invoke my-forecast-tool with query and horizon", 1437 | api: "http://127.0.0.1:5000/api/tool/my-forecast-tool/invoke", 1438 | requestHeader: map[string]string{}, 1439 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"history_data": "%s", "timestamp_col": "ts", "data_col": "data", "horizon": 5}`, historyDataQuery))), 1440 | want: `"forecast_timestamp"`, 1441 | isErr: false, 1442 | }, 1443 | { 1444 | name: "invoke my-forecast-tool with id_cols", 1445 | api: "http://127.0.0.1:5000/api/tool/my-forecast-tool/invoke", 1446 | requestHeader: 
map[string]string{}, 1447 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"history_data": "%s", "timestamp_col": "ts", "data_col": "data", "id_cols": ["id"]}`, historyDataTable))), 1448 | want: `"id"`, 1449 | isErr: false, 1450 | }, 1451 | { 1452 | name: "invoke my-auth-forecast-tool with auth token", 1453 | api: "http://127.0.0.1:5000/api/tool/my-auth-forecast-tool/invoke", 1454 | requestHeader: map[string]string{"my-google-auth_token": idToken}, 1455 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"history_data": "%s", "timestamp_col": "ts", "data_col": "data"}`, historyDataTable))), 1456 | want: `"forecast_timestamp"`, 1457 | isErr: false, 1458 | }, 1459 | { 1460 | name: "invoke my-auth-forecast-tool with invalid auth token", 1461 | api: "http://127.0.0.1:5000/api/tool/my-auth-forecast-tool/invoke", 1462 | requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"}, 1463 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"history_data": "%s", "timestamp_col": "ts", "data_col": "data"}`, historyDataTable))), 1464 | isErr: true, 1465 | }, 1466 | { 1467 | name: "Invoke my-client-auth-forecast-tool with auth token", 1468 | api: "http://127.0.0.1:5000/api/tool/my-client-auth-forecast-tool/invoke", 1469 | requestHeader: map[string]string{"Authorization": accessToken}, 1470 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"history_data": "%s", "timestamp_col": "ts", "data_col": "data"}`, historyDataTable))), 1471 | want: `"forecast_timestamp"`, 1472 | isErr: false, 1473 | }, 1474 | { 1475 | name: "Invoke my-client-auth-forecast-tool without auth token", 1476 | api: "http://127.0.0.1:5000/api/tool/my-client-auth-forecast-tool/invoke", 1477 | requestHeader: map[string]string{}, 1478 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"history_data": "%s", "timestamp_col": "ts", "data_col": "data"}`, historyDataTable))), 1479 | isErr: true, 1480 | }, 1481 | { 1482 | 1483 | name: "Invoke my-client-auth-forecast-tool with invalid auth token", 1484 | api: "http://127.0.0.1:5000/api/tool/my-client-auth-forecast-tool/invoke", 1485 | requestHeader: map[string]string{"Authorization": "Bearer invalid-token"}, 1486 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"history_data": "%s", "timestamp_col": "ts", "data_col": "data"}`, historyDataTable))), 1487 | isErr: true, 1488 | }, 1489 | } 1490 | for _, tc := range invokeTcs { 1491 | t.Run(tc.name, func(t *testing.T) { 1492 | // Send Tool invocation request 1493 | req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody) 1494 | if err != nil { 1495 | t.Fatalf("unable to create request: %s", err) 1496 | } 1497 | req.Header.Add("Content-type", "application/json") 1498 | for k, v := range tc.requestHeader { 1499 | req.Header.Add(k, v) 1500 | } 1501 | resp, err := http.DefaultClient.Do(req) 1502 | if err != nil { 1503 | t.Fatalf("unable to send request: %s", err) 1504 | } 1505 | defer resp.Body.Close() 1506 | 1507 | if resp.StatusCode != http.StatusOK { 1508 | if tc.isErr { 1509 | return 1510 | } 1511 | bodyBytes, _ := io.ReadAll(resp.Body) 1512 | t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) 1513 | } 1514 | 1515 | // Check response body 1516 | var body map[string]interface{} 1517 | err = json.NewDecoder(resp.Body).Decode(&body) 1518 | if err != nil { 1519 | t.Fatalf("error parsing response body") 1520 | } 1521 | 1522 | got, ok := body["result"].(string) 1523 | if !ok { 1524 | t.Fatalf("unable to find result in response body") 1525 | } 1526 | 1527 | if !strings.Contains(got, 
tc.want) { 1528 | t.Fatalf("expected %q to contain %q, but it did not", got, tc.want) 1529 | } 1530 | }) 1531 | } 1532 | } 1533 | 1534 | func runBigQueryAnalyzeContributionToolInvokeTest(t *testing.T, tableName string) { 1535 | idToken, err := tests.GetGoogleIdToken(tests.ClientId) 1536 | if err != nil { 1537 | t.Fatalf("error getting Google ID token: %s", err) 1538 | } 1539 | 1540 | // Get access token 1541 | accessToken, err := sources.GetIAMAccessToken(t.Context()) 1542 | if err != nil { 1543 | t.Fatalf("error getting access token from ADC: %s", err) 1544 | } 1545 | accessToken = "Bearer " + accessToken 1546 | 1547 | dataTable := strings.ReplaceAll(tableName, "`", "") 1548 | 1549 | invokeTcs := []struct { 1550 | name string 1551 | api string 1552 | requestHeader map[string]string 1553 | requestBody io.Reader 1554 | want string 1555 | isErr bool 1556 | }{ 1557 | { 1558 | name: "invoke my-analyze-contribution-tool without required params", 1559 | api: "http://127.0.0.1:5000/api/tool/my-analyze-contribution-tool/invoke", 1560 | requestHeader: map[string]string{}, 1561 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"input_data": "%s"}`, dataTable))), 1562 | isErr: true, 1563 | }, 1564 | { 1565 | name: "invoke my-analyze-contribution-tool with table", 1566 | api: "http://127.0.0.1:5000/api/tool/my-analyze-contribution-tool/invoke", 1567 | requestHeader: map[string]string{}, 1568 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"input_data": "%s", "contribution_metric": "SUM(metric)", "is_test_col": "is_test", "dimension_id_cols": ["dim1", "dim2"]}`, dataTable))), 1569 | want: `"relative_difference"`, 1570 | isErr: false, 1571 | }, 1572 | { 1573 | name: "invoke my-auth-analyze-contribution-tool with auth token", 1574 | api: "http://127.0.0.1:5000/api/tool/my-auth-analyze-contribution-tool/invoke", 1575 | requestHeader: map[string]string{"my-google-auth_token": idToken}, 1576 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"input_data": "%s", "contribution_metric": "SUM(metric)", "is_test_col": "is_test", "dimension_id_cols": ["dim1", "dim2"]}`, dataTable))), 1577 | want: `"relative_difference"`, 1578 | isErr: false, 1579 | }, 1580 | { 1581 | name: "invoke my-auth-analyze-contribution-tool with invalid auth token", 1582 | api: "http://127.0.0.1:5000/api/tool/my-auth-analyze-contribution-tool/invoke", 1583 | requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"}, 1584 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"input_data": "%s", "contribution_metric": "SUM(metric)", "is_test_col": "is_test", "dimension_id_cols": ["dim1", "dim2"]}`, dataTable))), 1585 | isErr: true, 1586 | }, 1587 | { 1588 | name: "Invoke my-client-auth-analyze-contribution-tool with auth token", 1589 | api: "http://127.0.0.1:5000/api/tool/my-client-auth-analyze-contribution-tool/invoke", 1590 | requestHeader: map[string]string{"Authorization": accessToken}, 1591 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"input_data": "%s", "contribution_metric": "SUM(metric)", "is_test_col": "is_test", "dimension_id_cols": ["dim1", "dim2"]}`, dataTable))), 1592 | want: `"relative_difference"`, 1593 | isErr: false, 1594 | }, 1595 | { 1596 | name: "Invoke my-client-auth-analyze-contribution-tool without auth token", 1597 | api: "http://127.0.0.1:5000/api/tool/my-client-auth-analyze-contribution-tool/invoke", 1598 | requestHeader: map[string]string{}, 1599 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"input_data": "%s", "contribution_metric": "SUM(metric)", "is_test_col": "is_test", 
"dimension_id_cols": ["dim1", "dim2"]}`, dataTable))), 1600 | isErr: true, 1601 | }, 1602 | { 1603 | 1604 | name: "Invoke my-client-auth-analyze-contribution-tool with invalid auth token", 1605 | api: "http://127.0.0.1:5000/api/tool/my-client-auth-analyze-contribution-tool/invoke", 1606 | requestHeader: map[string]string{"Authorization": "Bearer invalid-token"}, 1607 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"input_data": "%s", "contribution_metric": "SUM(metric)", "is_test_col": "is_test", "dimension_id_cols": ["dim1", "dim2"]}`, dataTable))), 1608 | isErr: true, 1609 | }, 1610 | } 1611 | for _, tc := range invokeTcs { 1612 | t.Run(tc.name, func(t *testing.T) { 1613 | // Send Tool invocation request 1614 | req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody) 1615 | if err != nil { 1616 | t.Fatalf("unable to create request: %s", err) 1617 | } 1618 | req.Header.Add("Content-type", "application/json") 1619 | for k, v := range tc.requestHeader { 1620 | req.Header.Add(k, v) 1621 | } 1622 | resp, err := http.DefaultClient.Do(req) 1623 | if err != nil { 1624 | t.Fatalf("unable to send request: %s", err) 1625 | } 1626 | defer resp.Body.Close() 1627 | 1628 | if resp.StatusCode != http.StatusOK { 1629 | if tc.isErr { 1630 | return 1631 | } 1632 | bodyBytes, _ := io.ReadAll(resp.Body) 1633 | t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) 1634 | } 1635 | 1636 | // Check response body 1637 | var body map[string]interface{} 1638 | err = json.NewDecoder(resp.Body).Decode(&body) 1639 | if err != nil { 1640 | t.Fatalf("error parsing response body") 1641 | } 1642 | 1643 | got, ok := body["result"].(string) 1644 | if !ok { 1645 | t.Fatalf("unable to find result in response body") 1646 | } 1647 | 1648 | if !strings.Contains(got, tc.want) { 1649 | t.Fatalf("expected %q to contain %q, but it did not", got, tc.want) 1650 | } 1651 | }) 1652 | } 1653 | } 1654 | 1655 | func runBigQueryDataTypeTests(t *testing.T) { 1656 | // Test tool invoke endpoint 1657 | invokeTcs := []struct { 1658 | name string 1659 | api string 1660 | requestHeader map[string]string 1661 | requestBody io.Reader 1662 | want string 1663 | isErr bool 1664 | }{ 1665 | { 1666 | name: "invoke my-scalar-datatype-tool with values", 1667 | api: "http://127.0.0.1:5000/api/tool/my-scalar-datatype-tool/invoke", 1668 | requestHeader: map[string]string{}, 1669 | requestBody: bytes.NewBuffer([]byte(`{"int_val": 123, "string_val": "hello", "float_val": 3.14, "bool_val": true}`)), 1670 | want: `[{"bool_val":true,"float_val":3.14,"id":1,"int_val":123,"string_val":"hello"}]`, 1671 | isErr: false, 1672 | }, 1673 | { 1674 | name: "invoke my-scalar-datatype-tool with missing params", 1675 | api: "http://127.0.0.1:5000/api/tool/my-scalar-datatype-tool/invoke", 1676 | requestHeader: map[string]string{}, 1677 | requestBody: bytes.NewBuffer([]byte(`{"int_val": 123}`)), 1678 | isErr: true, 1679 | }, 1680 | { 1681 | name: "invoke my-array-datatype-tool", 1682 | api: "http://127.0.0.1:5000/api/tool/my-array-datatype-tool/invoke", 1683 | requestHeader: map[string]string{}, 1684 | requestBody: bytes.NewBuffer([]byte(`{"int_array": [123, 789], "string_array": ["hello", "test"], "float_array": [3.14, 100.1], "bool_array": [true]}`)), 1685 | want: `[{"bool_val":true,"float_val":3.14,"id":1,"int_val":123,"string_val":"hello"},{"bool_val":true,"float_val":100.1,"id":3,"int_val":789,"string_val":"test"}]`, 1686 | isErr: false, 1687 | }, 1688 | } 1689 | for _, tc := range invokeTcs { 1690 | t.Run(tc.name, 
func(t *testing.T) { 1691 | // Send Tool invocation request 1692 | req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody) 1693 | if err != nil { 1694 | t.Fatalf("unable to create request: %s", err) 1695 | } 1696 | req.Header.Add("Content-type", "application/json") 1697 | for k, v := range tc.requestHeader { 1698 | req.Header.Add(k, v) 1699 | } 1700 | resp, err := http.DefaultClient.Do(req) 1701 | if err != nil { 1702 | t.Fatalf("unable to send request: %s", err) 1703 | } 1704 | defer resp.Body.Close() 1705 | 1706 | if resp.StatusCode != http.StatusOK { 1707 | if tc.isErr { 1708 | return 1709 | } 1710 | bodyBytes, _ := io.ReadAll(resp.Body) 1711 | t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) 1712 | } 1713 | 1714 | // Check response body 1715 | var body map[string]interface{} 1716 | err = json.NewDecoder(resp.Body).Decode(&body) 1717 | if err != nil { 1718 | t.Fatalf("error parsing response body") 1719 | } 1720 | 1721 | got, ok := body["result"].(string) 1722 | if !ok { 1723 | t.Fatalf("unable to find result in response body") 1724 | } 1725 | 1726 | if got != tc.want { 1727 | t.Fatalf("unexpected value: got %q, want %q", got, tc.want) 1728 | } 1729 | }) 1730 | } 1731 | } 1732 | 1733 | func runBigQueryListDatasetToolInvokeTest(t *testing.T, datasetWant string) { 1734 | // Get ID token 1735 | idToken, err := tests.GetGoogleIdToken(tests.ClientId) 1736 | if err != nil { 1737 | t.Fatalf("error getting Google ID token: %s", err) 1738 | } 1739 | 1740 | // Get access token 1741 | accessToken, err := sources.GetIAMAccessToken(t.Context()) 1742 | if err != nil { 1743 | t.Fatalf("error getting access token from ADC: %s", err) 1744 | } 1745 | accessToken = "Bearer " + accessToken 1746 | 1747 | // Test tool invoke endpoint 1748 | invokeTcs := []struct { 1749 | name string 1750 | api string 1751 | requestHeader map[string]string 1752 | requestBody io.Reader 1753 | want string 1754 | isErr bool 1755 | }{ 1756 | { 1757 | name: "invoke my-list-dataset-ids-tool", 1758 | api: "http://127.0.0.1:5000/api/tool/my-list-dataset-ids-tool/invoke", 1759 | requestHeader: map[string]string{}, 1760 | requestBody: bytes.NewBuffer([]byte(`{}`)), 1761 | isErr: false, 1762 | want: datasetWant, 1763 | }, 1764 | { 1765 | name: "invoke my-list-dataset-ids-tool with project", 1766 | api: "http://127.0.0.1:5000/api/tool/my-auth-list-dataset-ids-tool/invoke", 1767 | requestHeader: map[string]string{"my-google-auth_token": idToken}, 1768 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"project\":\"%s\"}", BigqueryProject))), 1769 | isErr: false, 1770 | want: datasetWant, 1771 | }, 1772 | { 1773 | name: "invoke my-list-dataset-ids-tool with non-existent project", 1774 | api: "http://127.0.0.1:5000/api/tool/my-auth-list-dataset-ids-tool/invoke", 1775 | requestHeader: map[string]string{"my-google-auth_token": idToken}, 1776 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"project\":\"%s-%s\"}", BigqueryProject, uuid.NewString()))), 1777 | isErr: true, 1778 | }, 1779 | { 1780 | name: "invoke my-auth-list-dataset-ids-tool", 1781 | api: "http://127.0.0.1:5000/api/tool/my-auth-list-dataset-ids-tool/invoke", 1782 | requestHeader: map[string]string{"my-google-auth_token": idToken}, 1783 | requestBody: bytes.NewBuffer([]byte(`{}`)), 1784 | isErr: false, 1785 | want: datasetWant, 1786 | }, 1787 | { 1788 | name: "Invoke my-client-auth-list-dataset-ids-tool with auth token", 1789 | api: "http://127.0.0.1:5000/api/tool/my-client-auth-list-dataset-ids-tool/invoke", 1790 | 
requestHeader: map[string]string{"Authorization": accessToken}, 1791 | requestBody: bytes.NewBuffer([]byte(`{}`)), 1792 | isErr: false, 1793 | want: datasetWant, 1794 | }, 1795 | { 1796 | name: "Invoke my-client-auth-list-dataset-ids-tool without auth token", 1797 | api: "http://127.0.0.1:5000/api/tool/my-client-auth-list-dataset-ids-tool/invoke", 1798 | requestHeader: map[string]string{}, 1799 | requestBody: bytes.NewBuffer([]byte(`{}`)), 1800 | isErr: true, 1801 | }, 1802 | { 1803 | 1804 | name: "Invoke my-client-auth-list-dataset-ids-tool with invalid auth token", 1805 | api: "http://127.0.0.1:5000/api/tool/my-client-auth-list-dataset-ids-tool/invoke", 1806 | requestHeader: map[string]string{"Authorization": "Bearer invalid-token"}, 1807 | requestBody: bytes.NewBuffer([]byte(`{}`)), 1808 | isErr: true, 1809 | }, 1810 | } 1811 | for _, tc := range invokeTcs { 1812 | t.Run(tc.name, func(t *testing.T) { 1813 | // Send Tool invocation request 1814 | req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody) 1815 | if err != nil { 1816 | t.Fatalf("unable to create request: %s", err) 1817 | } 1818 | req.Header.Add("Content-type", "application/json") 1819 | for k, v := range tc.requestHeader { 1820 | req.Header.Add(k, v) 1821 | } 1822 | resp, err := http.DefaultClient.Do(req) 1823 | if err != nil { 1824 | t.Fatalf("unable to send request: %s", err) 1825 | } 1826 | defer resp.Body.Close() 1827 | 1828 | if resp.StatusCode != http.StatusOK { 1829 | if tc.isErr { 1830 | return 1831 | } 1832 | bodyBytes, _ := io.ReadAll(resp.Body) 1833 | t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) 1834 | } 1835 | 1836 | // Check response body 1837 | var body map[string]interface{} 1838 | err = json.NewDecoder(resp.Body).Decode(&body) 1839 | if err != nil { 1840 | t.Fatalf("error parsing response body") 1841 | } 1842 | 1843 | got, ok := body["result"].(string) 1844 | if !ok { 1845 | t.Fatalf("unable to find result in response body") 1846 | } 1847 | 1848 | if !strings.Contains(got, tc.want) { 1849 | t.Fatalf("expected %q to contain %q, but it did not", got, tc.want) 1850 | } 1851 | }) 1852 | } 1853 | } 1854 | 1855 | func runBigQueryGetDatasetInfoToolInvokeTest(t *testing.T, datasetName, datasetInfoWant string) { 1856 | // Get ID token 1857 | idToken, err := tests.GetGoogleIdToken(tests.ClientId) 1858 | if err != nil { 1859 | t.Fatalf("error getting Google ID token: %s", err) 1860 | } 1861 | 1862 | // Get access token 1863 | accessToken, err := sources.GetIAMAccessToken(t.Context()) 1864 | if err != nil { 1865 | t.Fatalf("error getting access token from ADC: %s", err) 1866 | } 1867 | accessToken = "Bearer " + accessToken 1868 | 1869 | // Test tool invoke endpoint 1870 | invokeTcs := []struct { 1871 | name string 1872 | api string 1873 | requestHeader map[string]string 1874 | requestBody io.Reader 1875 | want string 1876 | isErr bool 1877 | }{ 1878 | { 1879 | name: "invoke my-get-dataset-info-tool without body", 1880 | api: "http://127.0.0.1:5000/api/tool/my-get-dataset-info-tool/invoke", 1881 | requestHeader: map[string]string{}, 1882 | requestBody: bytes.NewBuffer([]byte(`{}`)), 1883 | isErr: true, 1884 | }, 1885 | { 1886 | name: "invoke my-get-dataset-info-tool", 1887 | api: "http://127.0.0.1:5000/api/tool/my-get-dataset-info-tool/invoke", 1888 | requestHeader: map[string]string{}, 1889 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"dataset\":\"%s\"}", datasetName))), 1890 | want: datasetInfoWant, 1891 | isErr: false, 1892 | }, 1893 | { 1894 | name: 
"Invoke my-auth-get-dataset-info-tool with correct project", 1895 | api: "http://127.0.0.1:5000/api/tool/my-auth-get-dataset-info-tool/invoke", 1896 | requestHeader: map[string]string{"my-google-auth_token": idToken}, 1897 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"project\":\"%s\", \"dataset\":\"%s\"}", BigqueryProject, datasetName))), 1898 | want: datasetInfoWant, 1899 | isErr: false, 1900 | }, 1901 | { 1902 | name: "Invoke my-auth-get-dataset-info-tool with non-existent project", 1903 | api: "http://127.0.0.1:5000/api/tool/my-auth-get-dataset-info-tool/invoke", 1904 | requestHeader: map[string]string{"my-google-auth_token": idToken}, 1905 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"project\":\"%s-%s\", \"dataset\":\"%s\"}", BigqueryProject, uuid.NewString(), datasetName))), 1906 | isErr: true, 1907 | }, 1908 | { 1909 | name: "invoke my-auth-get-dataset-info-tool without body", 1910 | api: "http://127.0.0.1:5000/api/tool/my-get-dataset-info-tool/invoke", 1911 | requestHeader: map[string]string{}, 1912 | requestBody: bytes.NewBuffer([]byte(`{}`)), 1913 | isErr: true, 1914 | }, 1915 | { 1916 | name: "Invoke my-auth-get-dataset-info-tool with auth token", 1917 | api: "http://127.0.0.1:5000/api/tool/my-auth-get-dataset-info-tool/invoke", 1918 | requestHeader: map[string]string{"my-google-auth_token": idToken}, 1919 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"dataset\":\"%s\"}", datasetName))), 1920 | want: datasetInfoWant, 1921 | isErr: false, 1922 | }, 1923 | { 1924 | name: "Invoke my-auth-get-dataset-info-tool with invalid auth token", 1925 | api: "http://127.0.0.1:5000/api/tool/my-auth-get-dataset-info-tool/invoke", 1926 | requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"}, 1927 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"dataset\":\"%s\"}", datasetName))), 1928 | isErr: true, 1929 | }, 1930 | { 1931 | name: "Invoke my-auth-get-dataset-info-tool without auth token", 1932 | api: "http://127.0.0.1:5000/api/tool/my-auth-get-dataset-info-tool/invoke", 1933 | requestHeader: map[string]string{}, 1934 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"dataset\":\"%s\"}", datasetName))), 1935 | isErr: true, 1936 | }, 1937 | { 1938 | name: "Invoke my-client-auth-get-dataset-info-tool with auth token", 1939 | api: "http://127.0.0.1:5000/api/tool/my-client-auth-get-dataset-info-tool/invoke", 1940 | requestHeader: map[string]string{"Authorization": accessToken}, 1941 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"dataset\":\"%s\"}", datasetName))), 1942 | want: datasetInfoWant, 1943 | isErr: false, 1944 | }, 1945 | { 1946 | name: "Invoke my-client-auth-get-dataset-info-tool without auth token", 1947 | api: "http://127.0.0.1:5000/api/tool/my-client-auth-get-dataset-info-tool/invoke", 1948 | requestHeader: map[string]string{}, 1949 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"dataset\":\"%s\"}", datasetName))), 1950 | isErr: true, 1951 | }, 1952 | { 1953 | 1954 | name: "Invoke my-client-auth-get-dataset-info-tool with invalid auth token", 1955 | api: "http://127.0.0.1:5000/api/tool/my-client-auth-get-dataset-info-tool/invoke", 1956 | requestHeader: map[string]string{"Authorization": "Bearer invalid-token"}, 1957 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"dataset\":\"%s\"}", datasetName))), 1958 | isErr: true, 1959 | }, 1960 | } 1961 | for _, tc := range invokeTcs { 1962 | t.Run(tc.name, func(t *testing.T) { 1963 | // Send Tool invocation request 1964 | req, err := http.NewRequest(http.MethodPost, tc.api, 
tc.requestBody) 1965 | if err != nil { 1966 | t.Fatalf("unable to create request: %s", err) 1967 | } 1968 | req.Header.Add("Content-type", "application/json") 1969 | for k, v := range tc.requestHeader { 1970 | req.Header.Add(k, v) 1971 | } 1972 | resp, err := http.DefaultClient.Do(req) 1973 | if err != nil { 1974 | t.Fatalf("unable to send request: %s", err) 1975 | } 1976 | defer resp.Body.Close() 1977 | 1978 | if resp.StatusCode != http.StatusOK { 1979 | if tc.isErr { 1980 | return 1981 | } 1982 | bodyBytes, _ := io.ReadAll(resp.Body) 1983 | t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) 1984 | } 1985 | 1986 | // Check response body 1987 | var body map[string]interface{} 1988 | err = json.NewDecoder(resp.Body).Decode(&body) 1989 | if err != nil { 1990 | t.Fatalf("error parsing response body") 1991 | } 1992 | 1993 | got, ok := body["result"].(string) 1994 | if !ok { 1995 | t.Fatalf("unable to find result in response body") 1996 | } 1997 | 1998 | if !strings.Contains(got, tc.want) { 1999 | t.Fatalf("expected %q to contain %q, but it did not", got, tc.want) 2000 | } 2001 | }) 2002 | } 2003 | } 2004 | 2005 | func runBigQueryListTableIdsToolInvokeTest(t *testing.T, datasetName, tablename_want string) { 2006 | // Get ID token 2007 | idToken, err := tests.GetGoogleIdToken(tests.ClientId) 2008 | if err != nil { 2009 | t.Fatalf("error getting Google ID token: %s", err) 2010 | } 2011 | 2012 | // Get access token 2013 | accessToken, err := sources.GetIAMAccessToken(t.Context()) 2014 | if err != nil { 2015 | t.Fatalf("error getting access token from ADC: %s", err) 2016 | } 2017 | accessToken = "Bearer " + accessToken 2018 | 2019 | // Test tool invoke endpoint 2020 | invokeTcs := []struct { 2021 | name string 2022 | api string 2023 | requestHeader map[string]string 2024 | requestBody io.Reader 2025 | want string 2026 | isErr bool 2027 | }{ 2028 | { 2029 | name: "invoke my-list-table-ids-tool without body", 2030 | api: "http://127.0.0.1:5000/api/tool/my-list-table-ids-tool/invoke", 2031 | requestHeader: map[string]string{}, 2032 | requestBody: bytes.NewBuffer([]byte(`{}`)), 2033 | isErr: true, 2034 | }, 2035 | { 2036 | name: "invoke my-list-table-ids-tool", 2037 | api: "http://127.0.0.1:5000/api/tool/my-list-table-ids-tool/invoke", 2038 | requestHeader: map[string]string{}, 2039 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"dataset\":\"%s\"}", datasetName))), 2040 | want: tablename_want, 2041 | isErr: false, 2042 | }, 2043 | { 2044 | name: "invoke my-list-table-ids-tool without body", 2045 | api: "http://127.0.0.1:5000/api/tool/my-list-table-ids-tool/invoke", 2046 | requestHeader: map[string]string{}, 2047 | requestBody: bytes.NewBuffer([]byte(`{}`)), 2048 | isErr: true, 2049 | }, 2050 | { 2051 | name: "Invoke my-auth-list-table-ids-tool with auth token", 2052 | api: "http://127.0.0.1:5000/api/tool/my-auth-list-table-ids-tool/invoke", 2053 | requestHeader: map[string]string{"my-google-auth_token": idToken}, 2054 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"dataset\":\"%s\"}", datasetName))), 2055 | want: tablename_want, 2056 | isErr: false, 2057 | }, 2058 | { 2059 | name: "Invoke my-auth-list-table-ids-tool with correct project", 2060 | api: "http://127.0.0.1:5000/api/tool/my-auth-list-table-ids-tool/invoke", 2061 | requestHeader: map[string]string{"my-google-auth_token": idToken}, 2062 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"project\":\"%s\", \"dataset\":\"%s\"}", BigqueryProject, datasetName))), 2063 | want: tablename_want, 
2064 | isErr: false, 2065 | }, 2066 | { 2067 | name: "Invoke my-auth-list-table-ids-tool with non-existent project", 2068 | api: "http://127.0.0.1:5000/api/tool/my-auth-list-table-ids-tool/invoke", 2069 | requestHeader: map[string]string{"my-google-auth_token": idToken}, 2070 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"project\":\"%s-%s\", \"dataset\":\"%s\"}", BigqueryProject, uuid.NewString(), datasetName))), 2071 | isErr: true, 2072 | }, 2073 | { 2074 | name: "Invoke my-auth-list-table-ids-tool with invalid auth token", 2075 | api: "http://127.0.0.1:5000/api/tool/my-auth-list-table-ids-tool/invoke", 2076 | requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"}, 2077 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"dataset\":\"%s\"}", datasetName))), 2078 | isErr: true, 2079 | }, 2080 | { 2081 | name: "Invoke my-auth-list-table-ids-tool without auth token", 2082 | api: "http://127.0.0.1:5000/api/tool/my-auth-list-table-ids-tool/invoke", 2083 | requestHeader: map[string]string{}, 2084 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"dataset\":\"%s\"}", datasetName))), 2085 | isErr: true, 2086 | }, 2087 | { 2088 | name: "Invoke my-client-auth-list-table-ids-tool with auth token", 2089 | api: "http://127.0.0.1:5000/api/tool/my-client-auth-list-table-ids-tool/invoke", 2090 | requestHeader: map[string]string{"Authorization": accessToken}, 2091 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"dataset\":\"%s\"}", datasetName))), 2092 | want: tablename_want, 2093 | isErr: false, 2094 | }, 2095 | { 2096 | name: "Invoke my-client-auth-list-table-ids-tool without auth token", 2097 | api: "http://127.0.0.1:5000/api/tool/my-client-auth-list-table-ids-tool/invoke", 2098 | requestHeader: map[string]string{}, 2099 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"dataset\":\"%s\"}", datasetName))), 2100 | isErr: true, 2101 | }, 2102 | { 2103 | 2104 | name: "Invoke my-client-auth-list-table-ids-tool with invalid auth token", 2105 | api: "http://127.0.0.1:5000/api/tool/my-client-auth-list-table-ids-tool/invoke", 2106 | requestHeader: map[string]string{"Authorization": "Bearer invalid-token"}, 2107 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"dataset\":\"%s\"}", datasetName))), 2108 | isErr: true, 2109 | }, 2110 | } 2111 | for _, tc := range invokeTcs { 2112 | t.Run(tc.name, func(t *testing.T) { 2113 | // Send Tool invocation request 2114 | req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody) 2115 | if err != nil { 2116 | t.Fatalf("unable to create request: %s", err) 2117 | } 2118 | req.Header.Add("Content-type", "application/json") 2119 | for k, v := range tc.requestHeader { 2120 | req.Header.Add(k, v) 2121 | } 2122 | resp, err := http.DefaultClient.Do(req) 2123 | if err != nil { 2124 | t.Fatalf("unable to send request: %s", err) 2125 | } 2126 | defer resp.Body.Close() 2127 | 2128 | if resp.StatusCode != http.StatusOK { 2129 | if tc.isErr { 2130 | return 2131 | } 2132 | bodyBytes, _ := io.ReadAll(resp.Body) 2133 | t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) 2134 | } 2135 | 2136 | // Check response body 2137 | var body map[string]interface{} 2138 | err = json.NewDecoder(resp.Body).Decode(&body) 2139 | if err != nil { 2140 | t.Fatalf("error parsing response body") 2141 | } 2142 | 2143 | got, ok := body["result"].(string) 2144 | if !ok { 2145 | t.Fatalf("unable to find result in response body") 2146 | } 2147 | 2148 | if !strings.Contains(got, tc.want) { 2149 | t.Fatalf("expected %q to 
contain %q, but it did not", got, tc.want) 2150 | } 2151 | }) 2152 | } 2153 | } 2154 | 2155 | func runBigQueryGetTableInfoToolInvokeTest(t *testing.T, datasetName, tableName, tableInfoWant string) { 2156 | // Get ID token 2157 | idToken, err := tests.GetGoogleIdToken(tests.ClientId) 2158 | if err != nil { 2159 | t.Fatalf("error getting Google ID token: %s", err) 2160 | } 2161 | 2162 | // Get access token 2163 | accessToken, err := sources.GetIAMAccessToken(t.Context()) 2164 | if err != nil { 2165 | t.Fatalf("error getting access token from ADC: %s", err) 2166 | } 2167 | accessToken = "Bearer " + accessToken 2168 | 2169 | // Test tool invoke endpoint 2170 | invokeTcs := []struct { 2171 | name string 2172 | api string 2173 | requestHeader map[string]string 2174 | requestBody io.Reader 2175 | want string 2176 | isErr bool 2177 | }{ 2178 | { 2179 | name: "invoke my-get-table-info-tool without body", 2180 | api: "http://127.0.0.1:5000/api/tool/my-get-table-info-tool/invoke", 2181 | requestHeader: map[string]string{}, 2182 | requestBody: bytes.NewBuffer([]byte(`{}`)), 2183 | isErr: true, 2184 | }, 2185 | { 2186 | name: "invoke my-get-table-info-tool", 2187 | api: "http://127.0.0.1:5000/api/tool/my-get-table-info-tool/invoke", 2188 | requestHeader: map[string]string{}, 2189 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"dataset\":\"%s\", \"table\":\"%s\"}", datasetName, tableName))), 2190 | want: tableInfoWant, 2191 | isErr: false, 2192 | }, 2193 | { 2194 | name: "invoke my-auth-get-table-info-tool without body", 2195 | api: "http://127.0.0.1:5000/api/tool/my-get-table-info-tool/invoke", 2196 | requestHeader: map[string]string{}, 2197 | requestBody: bytes.NewBuffer([]byte(`{}`)), 2198 | isErr: true, 2199 | }, 2200 | { 2201 | name: "Invoke my-auth-get-table-info-tool with auth token", 2202 | api: "http://127.0.0.1:5000/api/tool/my-auth-get-table-info-tool/invoke", 2203 | requestHeader: map[string]string{"my-google-auth_token": idToken}, 2204 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"dataset\":\"%s\", \"table\":\"%s\"}", datasetName, tableName))), 2205 | want: tableInfoWant, 2206 | isErr: false, 2207 | }, 2208 | { 2209 | name: "Invoke my-auth-get-table-info-tool with correct project", 2210 | api: "http://127.0.0.1:5000/api/tool/my-auth-get-table-info-tool/invoke", 2211 | requestHeader: map[string]string{"my-google-auth_token": idToken}, 2212 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"project\":\"%s\", \"dataset\":\"%s\", \"table\":\"%s\"}", BigqueryProject, datasetName, tableName))), 2213 | want: tableInfoWant, 2214 | isErr: false, 2215 | }, 2216 | { 2217 | name: "Invoke my-auth-get-table-info-tool with non-existent project", 2218 | api: "http://127.0.0.1:5000/api/tool/my-auth-get-table-info-tool/invoke", 2219 | requestHeader: map[string]string{"my-google-auth_token": idToken}, 2220 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"project\":\"%s-%s\", \"dataset\":\"%s\", \"table\":\"%s\"}", BigqueryProject, uuid.NewString(), datasetName, tableName))), 2221 | isErr: true, 2222 | }, 2223 | { 2224 | name: "Invoke my-auth-get-table-info-tool with invalid auth token", 2225 | api: "http://127.0.0.1:5000/api/tool/my-auth-get-table-info-tool/invoke", 2226 | requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"}, 2227 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"dataset\":\"%s\", \"table\":\"%s\"}", datasetName, tableName))), 2228 | isErr: true, 2229 | }, 2230 | { 2231 | name: "Invoke my-auth-get-table-info-tool without auth token", 2232 | api: 
"http://127.0.0.1:5000/api/tool/my-auth-get-table-info-tool/invoke", 2233 | requestHeader: map[string]string{}, 2234 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"dataset\":\"%s\", \"table\":\"%s\"}", datasetName, tableName))), 2235 | isErr: true, 2236 | }, 2237 | { 2238 | name: "Invoke my-client-auth-get-table-info-tool with auth token", 2239 | api: "http://127.0.0.1:5000/api/tool/my-client-auth-get-table-info-tool/invoke", 2240 | requestHeader: map[string]string{"Authorization": accessToken}, 2241 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"dataset\":\"%s\", \"table\":\"%s\"}", datasetName, tableName))), 2242 | want: tableInfoWant, 2243 | isErr: false, 2244 | }, 2245 | { 2246 | name: "Invoke my-client-auth-get-table-info-tool without auth token", 2247 | api: "http://127.0.0.1:5000/api/tool/my-client-auth-get-table-info-tool/invoke", 2248 | requestHeader: map[string]string{}, 2249 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"dataset\":\"%s\", \"table\":\"%s\"}", datasetName, tableName))), 2250 | isErr: true, 2251 | }, 2252 | { 2253 | 2254 | name: "Invoke my-client-auth-get-table-info-tool with invalid auth token", 2255 | api: "http://127.0.0.1:5000/api/tool/my-client-auth-get-table-info-tool/invoke", 2256 | requestHeader: map[string]string{"Authorization": "Bearer invalid-token"}, 2257 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"dataset\":\"%s\", \"table\":\"%s\"}", datasetName, tableName))), 2258 | isErr: true, 2259 | }, 2260 | } 2261 | for _, tc := range invokeTcs { 2262 | t.Run(tc.name, func(t *testing.T) { 2263 | // Send Tool invocation request 2264 | req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody) 2265 | if err != nil { 2266 | t.Fatalf("unable to create request: %s", err) 2267 | } 2268 | req.Header.Add("Content-type", "application/json") 2269 | for k, v := range tc.requestHeader { 2270 | req.Header.Add(k, v) 2271 | } 2272 | resp, err := http.DefaultClient.Do(req) 2273 | if err != nil { 2274 | t.Fatalf("unable to send request: %s", err) 2275 | } 2276 | defer resp.Body.Close() 2277 | 2278 | if resp.StatusCode != http.StatusOK { 2279 | if tc.isErr { 2280 | return 2281 | } 2282 | bodyBytes, _ := io.ReadAll(resp.Body) 2283 | t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) 2284 | } 2285 | 2286 | // Check response body 2287 | var body map[string]interface{} 2288 | err = json.NewDecoder(resp.Body).Decode(&body) 2289 | if err != nil { 2290 | t.Fatalf("error parsing response body") 2291 | } 2292 | 2293 | got, ok := body["result"].(string) 2294 | if !ok { 2295 | t.Fatalf("unable to find result in response body") 2296 | } 2297 | 2298 | if !strings.Contains(got, tc.want) { 2299 | t.Fatalf("expected %q to contain %q, but it did not", got, tc.want) 2300 | } 2301 | }) 2302 | } 2303 | } 2304 | 2305 | func runBigQueryConversationalAnalyticsInvokeTest(t *testing.T, datasetName, tableName, dataInsightsWant string) { 2306 | // Each test is expected to complete in under 10s, we set a 25s timeout with retries to avoid flaky tests. 
2307 | const maxRetries = 3 2308 | const requestTimeout = 25 * time.Second 2309 | // Get ID token 2310 | idToken, err := tests.GetGoogleIdToken(tests.ClientId) 2311 | if err != nil { 2312 | t.Fatalf("error getting Google ID token: %s", err) 2313 | } 2314 | 2315 | // Get access token 2316 | accessToken, err := sources.GetIAMAccessToken(t.Context()) 2317 | if err != nil { 2318 | t.Fatalf("error getting access token from ADC: %s", err) 2319 | } 2320 | accessToken = "Bearer " + accessToken 2321 | 2322 | tableRefsJSON := fmt.Sprintf(`[{"projectId":"%s","datasetId":"%s","tableId":"%s"}]`, BigqueryProject, datasetName, tableName) 2323 | 2324 | invokeTcs := []struct { 2325 | name string 2326 | api string 2327 | requestHeader map[string]string 2328 | requestBody io.Reader 2329 | want string 2330 | isErr bool 2331 | }{ 2332 | { 2333 | name: "invoke my-conversational-analytics-tool successfully", 2334 | api: "http://127.0.0.1:5000/api/tool/my-conversational-analytics-tool/invoke", 2335 | requestHeader: map[string]string{}, 2336 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf( 2337 | `{"user_query_with_context": "What are the names in the table?", "table_references": %q}`, 2338 | tableRefsJSON, 2339 | ))), 2340 | want: dataInsightsWant, 2341 | isErr: false, 2342 | }, 2343 | { 2344 | name: "invoke my-auth-conversational-analytics-tool with auth token", 2345 | api: "http://127.0.0.1:5000/api/tool/my-auth-conversational-analytics-tool/invoke", 2346 | requestHeader: map[string]string{"my-google-auth_token": idToken}, 2347 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf( 2348 | `{"user_query_with_context": "What are the names in the table?", "table_references": %q}`, 2349 | tableRefsJSON, 2350 | ))), 2351 | want: dataInsightsWant, 2352 | isErr: false, 2353 | }, 2354 | { 2355 | name: "invoke my-auth-conversational-analytics-tool without auth token", 2356 | api: "http://127.0.0.1:5000/api/tool/my-auth-conversational-analytics-tool/invoke", 2357 | requestHeader: map[string]string{}, 2358 | requestBody: bytes.NewBuffer([]byte(`{"user_query_with_context": "What are the names in the table?"}`)), 2359 | isErr: true, 2360 | }, 2361 | { 2362 | name: "Invoke my-client-auth-conversational-analytics-tool with auth token", 2363 | api: "http://127.0.0.1:5000/api/tool/my-client-auth-conversational-analytics-tool/invoke", 2364 | requestHeader: map[string]string{"Authorization": accessToken}, 2365 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf( 2366 | `{"user_query_with_context": "What are the names in the table?", "table_references": %q}`, 2367 | tableRefsJSON, 2368 | ))), 2369 | want: "[{\"f0_\":1}]", 2370 | isErr: false, 2371 | }, 2372 | { 2373 | name: "Invoke my-client-auth-conversational-analytics-tool without auth token", 2374 | api: "http://127.0.0.1:5000/api/tool/my-client-auth-conversational-analytics-tool/invoke", 2375 | requestHeader: map[string]string{}, 2376 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf( 2377 | `{"user_query_with_context": "What are the names in the table?", "table_references": %q}`, 2378 | tableRefsJSON, 2379 | ))), 2380 | isErr: true, 2381 | }, 2382 | { 2383 | 2384 | name: "Invoke my-client-auth-conversational-analytics-tool with invalid auth token", 2385 | api: "http://127.0.0.1:5000/api/tool/my-client-auth-conversational-analytics-tool/invoke", 2386 | requestHeader: map[string]string{"Authorization": "Bearer invalid-token"}, 2387 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf( 2388 | `{"user_query_with_context": "What are the names in the table?", "table_references": %q}`, 
2389 | tableRefsJSON, 2390 | ))), 2391 | isErr: true, 2392 | }, 2393 | } 2394 | for _, tc := range invokeTcs { 2395 | t.Run(tc.name, func(t *testing.T) { 2396 | var resp *http.Response 2397 | var err error 2398 | 2399 | bodyBytes, err := io.ReadAll(tc.requestBody) 2400 | if err != nil { 2401 | t.Fatalf("failed to read request body: %v", err) 2402 | } 2403 | 2404 | req, err := http.NewRequest(http.MethodPost, tc.api, nil) 2405 | if err != nil { 2406 | t.Fatalf("unable to create request: %s", err) 2407 | } 2408 | req.Header.Set("Content-type", "application/json") 2409 | for k, v := range tc.requestHeader { 2410 | req.Header.Add(k, v) 2411 | } 2412 | 2413 | for i := 0; i < maxRetries; i++ { 2414 | ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) 2415 | defer cancel() 2416 | 2417 | req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) 2418 | req.GetBody = func() (io.ReadCloser, error) { 2419 | return io.NopCloser(bytes.NewReader(bodyBytes)), nil 2420 | } 2421 | reqWithCtx := req.WithContext(ctx) 2422 | 2423 | resp, err = http.DefaultClient.Do(reqWithCtx) 2424 | if err != nil { 2425 | // Retry on timeout. 2426 | if os.IsTimeout(err) { 2427 | t.Logf("Request timed out (attempt %d/%d), retrying...", i+1, maxRetries) 2428 | time.Sleep(5 * time.Second) 2429 | continue 2430 | } 2431 | t.Fatalf("unable to send request: %s", err) 2432 | } 2433 | if resp.StatusCode == http.StatusServiceUnavailable { 2434 | t.Logf("Received 503 Service Unavailable (attempt %d/%d), retrying...", i+1, maxRetries) 2435 | time.Sleep(15 * time.Second) 2436 | continue 2437 | } 2438 | break 2439 | } 2440 | 2441 | if err != nil { 2442 | t.Fatalf("Request failed after %d retries: %v", maxRetries, err) 2443 | } 2444 | defer resp.Body.Close() 2445 | 2446 | if resp.StatusCode != http.StatusOK { 2447 | if tc.isErr { 2448 | return 2449 | } 2450 | bodyBytes, _ := io.ReadAll(resp.Body) 2451 | t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) 2452 | } 2453 | 2454 | var body map[string]interface{} 2455 | err = json.NewDecoder(resp.Body).Decode(&body) 2456 | if err != nil { 2457 | t.Fatalf("error parsing response body: %v", err) 2458 | } 2459 | 2460 | got, ok := body["result"].(string) 2461 | if !ok { 2462 | t.Fatalf("unable to find result in response body") 2463 | } 2464 | 2465 | wantPattern := regexp.MustCompile(tc.want) 2466 | if !wantPattern.MatchString(got) { 2467 | t.Fatalf("response did not match the expected pattern.\nFull response:\n%s", got) 2468 | } 2469 | }) 2470 | } 2471 | } 2472 | 2473 | func runListDatasetIdsWithRestriction(t *testing.T, allowedDatasetName1, allowedDatasetName2 string) { 2474 | testCases := []struct { 2475 | name string 2476 | wantStatusCode int 2477 | wantElements []string 2478 | }{ 2479 | { 2480 | name: "invoke list-dataset-ids with restriction", 2481 | wantStatusCode: http.StatusOK, 2482 | wantElements: []string{ 2483 | fmt.Sprintf("%s.%s", BigqueryProject, allowedDatasetName1), 2484 | fmt.Sprintf("%s.%s", BigqueryProject, allowedDatasetName2), 2485 | }, 2486 | }, 2487 | } 2488 | 2489 | for _, tc := range testCases { 2490 | t.Run(tc.name, func(t *testing.T) { 2491 | body := bytes.NewBuffer([]byte(`{}`)) 2492 | resp, bodyBytes := tests.RunRequest(t, http.MethodPost, "http://127.0.0.1:5000/api/tool/list-dataset-ids-restricted/invoke", body, nil) 2493 | 2494 | if resp.StatusCode != tc.wantStatusCode { 2495 | t.Fatalf("unexpected status code: got %d, want %d. 
Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes)) 2496 | } 2497 | 2498 | var respBody map[string]interface{} 2499 | if err := json.Unmarshal(bodyBytes, &respBody); err != nil { 2500 | t.Fatalf("error parsing response body: %v", err) 2501 | } 2502 | 2503 | gotJSON, ok := respBody["result"].(string) 2504 | if !ok { 2505 | t.Fatalf("unable to find 'result' as a string in response body: %s", string(bodyBytes)) 2506 | } 2507 | 2508 | // Unmarshal the result string into a slice to compare contents. 2509 | var gotElements []string 2510 | if err := json.Unmarshal([]byte(gotJSON), &gotElements); err != nil { 2511 | t.Fatalf("error parsing result field JSON %q: %v", gotJSON, err) 2512 | } 2513 | 2514 | sort.Strings(gotElements) 2515 | sort.Strings(tc.wantElements) 2516 | if !reflect.DeepEqual(gotElements, tc.wantElements) { 2517 | t.Errorf("unexpected result:\n got: %v\nwant: %v", gotElements, tc.wantElements) 2518 | } 2519 | }) 2520 | } 2521 | } 2522 | 2523 | func runListTableIdsWithRestriction(t *testing.T, allowedDatasetName, disallowedDatasetName string, allowedTableNames ...string) { 2524 | sort.Strings(allowedTableNames) 2525 | var quotedNames []string 2526 | for _, name := range allowedTableNames { 2527 | quotedNames = append(quotedNames, fmt.Sprintf(`"%s"`, name)) 2528 | } 2529 | wantResult := fmt.Sprintf(`[%s]`, strings.Join(quotedNames, ",")) 2530 | 2531 | testCases := []struct { 2532 | name string 2533 | dataset string 2534 | wantStatusCode int 2535 | wantInResult string 2536 | wantInError string 2537 | }{ 2538 | { 2539 | name: "invoke on allowed dataset", 2540 | dataset: allowedDatasetName, 2541 | wantStatusCode: http.StatusOK, 2542 | wantInResult: wantResult, 2543 | }, 2544 | { 2545 | name: "invoke on disallowed dataset", 2546 | dataset: disallowedDatasetName, 2547 | wantStatusCode: http.StatusBadRequest, // Or the specific error code returned 2548 | wantInError: fmt.Sprintf("access denied to dataset '%s'", disallowedDatasetName), 2549 | }, 2550 | } 2551 | 2552 | for _, tc := range testCases { 2553 | t.Run(tc.name, func(t *testing.T) { 2554 | body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"dataset":"%s"}`, tc.dataset))) 2555 | req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:5000/api/tool/list-table-ids-restricted/invoke", body) 2556 | if err != nil { 2557 | t.Fatalf("unable to create request: %s", err) 2558 | } 2559 | req.Header.Add("Content-type", "application/json") 2560 | resp, err := http.DefaultClient.Do(req) 2561 | if err != nil { 2562 | t.Fatalf("unable to send request: %s", err) 2563 | } 2564 | defer resp.Body.Close() 2565 | 2566 | if resp.StatusCode != tc.wantStatusCode { 2567 | bodyBytes, _ := io.ReadAll(resp.Body) 2568 | t.Fatalf("unexpected status code: got %d, want %d. 
Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes)) 2569 | } 2570 | 2571 | if tc.wantInResult != "" { 2572 | var respBody map[string]interface{} 2573 | if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { 2574 | t.Fatalf("error parsing response body: %v", err) 2575 | } 2576 | got, ok := respBody["result"].(string) 2577 | if !ok { 2578 | t.Fatalf("unable to find result in response body") 2579 | } 2580 | 2581 | var gotSlice []string 2582 | if err := json.Unmarshal([]byte(got), &gotSlice); err != nil { 2583 | t.Fatalf("error unmarshalling result: %v", err) 2584 | } 2585 | sort.Strings(gotSlice) 2586 | sortedGotBytes, err := json.Marshal(gotSlice) 2587 | if err != nil { 2588 | t.Fatalf("error marshalling sorted result: %v", err) 2589 | } 2590 | 2591 | if string(sortedGotBytes) != tc.wantInResult { 2592 | t.Errorf("unexpected result: got %q, want %q", string(sortedGotBytes), tc.wantInResult) 2593 | } 2594 | } 2595 | 2596 | if tc.wantInError != "" { 2597 | bodyBytes, _ := io.ReadAll(resp.Body) 2598 | if !strings.Contains(string(bodyBytes), tc.wantInError) { 2599 | t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError) 2600 | } 2601 | } 2602 | }) 2603 | } 2604 | } 2605 | 2606 | func runGetTableInfoWithRestriction(t *testing.T, allowedDatasetName, disallowedDatasetName, allowedTableName, disallowedTableName string) { 2607 | testCases := []struct { 2608 | name string 2609 | dataset string 2610 | table string 2611 | wantStatusCode int 2612 | wantInError string 2613 | }{ 2614 | { 2615 | name: "invoke on allowed table", 2616 | dataset: allowedDatasetName, 2617 | table: allowedTableName, 2618 | wantStatusCode: http.StatusOK, 2619 | }, 2620 | { 2621 | name: "invoke on disallowed table", 2622 | dataset: disallowedDatasetName, 2623 | table: disallowedTableName, 2624 | wantStatusCode: http.StatusBadRequest, 2625 | wantInError: fmt.Sprintf("access denied to dataset '%s'", disallowedDatasetName), 2626 | }, 2627 | } 2628 | 2629 | for _, tc := range testCases { 2630 | t.Run(tc.name, func(t *testing.T) { 2631 | body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"dataset":"%s", "table":"%s"}`, tc.dataset, tc.table))) 2632 | req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:5000/api/tool/get-table-info-restricted/invoke", body) 2633 | if err != nil { 2634 | t.Fatalf("unable to create request: %s", err) 2635 | } 2636 | req.Header.Add("Content-type", "application/json") 2637 | resp, err := http.DefaultClient.Do(req) 2638 | if err != nil { 2639 | t.Fatalf("unable to send request: %s", err) 2640 | } 2641 | defer resp.Body.Close() 2642 | 2643 | if resp.StatusCode != tc.wantStatusCode { 2644 | bodyBytes, _ := io.ReadAll(resp.Body) 2645 | t.Fatalf("unexpected status code: got %d, want %d. 
Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes)) 2646 | } 2647 | 2648 | if tc.wantInError != "" { 2649 | bodyBytes, _ := io.ReadAll(resp.Body) 2650 | if !strings.Contains(string(bodyBytes), tc.wantInError) { 2651 | t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError) 2652 | } 2653 | } 2654 | }) 2655 | } 2656 | } 2657 | 2658 | func runExecuteSqlWithRestriction(t *testing.T, allowedTableFullName, disallowedTableFullName string) { 2659 | allowedTableParts := strings.Split(strings.Trim(allowedTableFullName, "`"), ".") 2660 | if len(allowedTableParts) != 3 { 2661 | t.Fatalf("invalid allowed table name format: %s", allowedTableFullName) 2662 | } 2663 | allowedDatasetID := allowedTableParts[1] 2664 | 2665 | testCases := []struct { 2666 | name string 2667 | sql string 2668 | wantStatusCode int 2669 | wantInError string 2670 | }{ 2671 | { 2672 | name: "invoke on allowed table", 2673 | sql: fmt.Sprintf("SELECT * FROM %s", allowedTableFullName), 2674 | wantStatusCode: http.StatusOK, 2675 | }, 2676 | { 2677 | name: "invoke on disallowed table", 2678 | sql: fmt.Sprintf("SELECT * FROM %s", disallowedTableFullName), 2679 | wantStatusCode: http.StatusBadRequest, 2680 | wantInError: fmt.Sprintf("query accesses dataset '%s', which is not in the allowed list", 2681 | strings.Join( 2682 | strings.Split(strings.Trim(disallowedTableFullName, "`"), ".")[0:2], 2683 | ".")), 2684 | }, 2685 | { 2686 | name: "disallowed create schema", 2687 | sql: "CREATE SCHEMA another_dataset", 2688 | wantStatusCode: http.StatusBadRequest, 2689 | wantInError: "dataset-level operations like 'CREATE_SCHEMA' are not allowed", 2690 | }, 2691 | { 2692 | name: "disallowed alter schema", 2693 | sql: fmt.Sprintf("ALTER SCHEMA %s SET OPTIONS(description='new one')", allowedDatasetID), 2694 | wantStatusCode: http.StatusBadRequest, 2695 | wantInError: "dataset-level operations like 'ALTER_SCHEMA' are not allowed", 2696 | }, 2697 | { 2698 | name: "disallowed create function", 2699 | sql: fmt.Sprintf("CREATE FUNCTION %s.my_func() RETURNS INT64 AS (1)", allowedDatasetID), 2700 | wantStatusCode: http.StatusBadRequest, 2701 | wantInError: "creating stored routines ('CREATE_FUNCTION') is not allowed", 2702 | }, 2703 | { 2704 | name: "disallowed create procedure", 2705 | sql: fmt.Sprintf("CREATE PROCEDURE %s.my_proc() BEGIN SELECT 1; END", allowedDatasetID), 2706 | wantStatusCode: http.StatusBadRequest, 2707 | wantInError: "unanalyzable statements like 'CREATE PROCEDURE' are not allowed", 2708 | }, 2709 | { 2710 | name: "disallowed execute immediate", 2711 | sql: "EXECUTE IMMEDIATE 'SELECT 1'", 2712 | wantStatusCode: http.StatusBadRequest, 2713 | wantInError: "EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place", 2714 | }, 2715 | } 2716 | 2717 | for _, tc := range testCases { 2718 | t.Run(tc.name, func(t *testing.T) { 2719 | body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"sql":"%s"}`, tc.sql))) 2720 | req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:5000/api/tool/execute-sql-restricted/invoke", body) 2721 | if err != nil { 2722 | t.Fatalf("unable to create request: %s", err) 2723 | } 2724 | req.Header.Add("Content-type", "application/json") 2725 | resp, err := http.DefaultClient.Do(req) 2726 | if err != nil { 2727 | t.Fatalf("unable to send request: %s", err) 2728 | } 2729 | defer resp.Body.Close() 2730 | 2731 | if resp.StatusCode != tc.wantStatusCode { 2732 | bodyBytes, _ := io.ReadAll(resp.Body) 2733 | t.Fatalf("unexpected status code: got %d, want %d. 
Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes)) 2734 | } 2735 | 2736 | if tc.wantInError != "" { 2737 | bodyBytes, _ := io.ReadAll(resp.Body) 2738 | if !strings.Contains(string(bodyBytes), tc.wantInError) { 2739 | t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError) 2740 | } 2741 | } 2742 | }) 2743 | } 2744 | } 2745 | 2746 | func runConversationalAnalyticsWithRestriction(t *testing.T, allowedDatasetName, disallowedDatasetName, allowedTableName, disallowedTableName string) { 2747 | allowedTableRefsJSON := fmt.Sprintf(`[{"projectId":"%s","datasetId":"%s","tableId":"%s"}]`, BigqueryProject, allowedDatasetName, allowedTableName) 2748 | disallowedTableRefsJSON := fmt.Sprintf(`[{"projectId":"%s","datasetId":"%s","tableId":"%s"}]`, BigqueryProject, disallowedDatasetName, disallowedTableName) 2749 | 2750 | testCases := []struct { 2751 | name string 2752 | tableRefs string 2753 | wantStatusCode int 2754 | wantInResult string 2755 | wantInError string 2756 | }{ 2757 | { 2758 | name: "invoke with allowed table", 2759 | tableRefs: allowedTableRefsJSON, 2760 | wantStatusCode: http.StatusOK, 2761 | wantInResult: `Answer`, 2762 | }, 2763 | { 2764 | name: "invoke with disallowed table", 2765 | tableRefs: disallowedTableRefsJSON, 2766 | wantStatusCode: http.StatusBadRequest, 2767 | wantInError: fmt.Sprintf("access to dataset '%s.%s' (from table '%s') is not allowed", BigqueryProject, disallowedDatasetName, disallowedTableName), 2768 | }, 2769 | } 2770 | 2771 | for _, tc := range testCases { 2772 | t.Run(tc.name, func(t *testing.T) { 2773 | requestBodyMap := map[string]any{ 2774 | "user_query_with_context": "What is in the table?", 2775 | "table_references": tc.tableRefs, 2776 | } 2777 | bodyBytes, err := json.Marshal(requestBodyMap) 2778 | if err != nil { 2779 | t.Fatalf("failed to marshal request body: %v", err) 2780 | } 2781 | body := bytes.NewBuffer(bodyBytes) 2782 | 2783 | req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:5000/api/tool/conversational-analytics-restricted/invoke", body) 2784 | if err != nil { 2785 | t.Fatalf("unable to create request: %s", err) 2786 | } 2787 | req.Header.Add("Content-type", "application/json") 2788 | resp, err := http.DefaultClient.Do(req) 2789 | if err != nil { 2790 | t.Fatalf("unable to send request: %s", err) 2791 | } 2792 | defer resp.Body.Close() 2793 | 2794 | if resp.StatusCode != tc.wantStatusCode { 2795 | bodyBytes, _ := io.ReadAll(resp.Body) 2796 | t.Fatalf("unexpected status code: got %d, want %d. 
Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes)) 2797 | } 2798 | 2799 | if tc.wantInResult != "" { 2800 | var respBody map[string]interface{} 2801 | if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { 2802 | t.Fatalf("error parsing response body: %v", err) 2803 | } 2804 | got, ok := respBody["result"].(string) 2805 | if !ok { 2806 | t.Fatalf("unable to find result in response body") 2807 | } 2808 | if !strings.Contains(got, tc.wantInResult) { 2809 | t.Errorf("unexpected result: got %q, want to contain %q", got, tc.wantInResult) 2810 | } 2811 | } 2812 | 2813 | if tc.wantInError != "" { 2814 | bodyBytes, _ := io.ReadAll(resp.Body) 2815 | if !strings.Contains(string(bodyBytes), tc.wantInError) { 2816 | t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError) 2817 | } 2818 | } 2819 | }) 2820 | } 2821 | } 2822 | 2823 | func runBigQuerySearchCatalogToolInvokeTest(t *testing.T, datasetName string, tableName string) { 2824 | // Get ID token 2825 | idToken, err := tests.GetGoogleIdToken(tests.ClientId) 2826 | if err != nil { 2827 | t.Fatalf("error getting Google ID token: %s", err) 2828 | } 2829 | 2830 | // Get access token 2831 | accessToken, err := sources.GetIAMAccessToken(t.Context()) 2832 | if err != nil { 2833 | t.Fatalf("error getting access token from ADC: %s", err) 2834 | } 2835 | accessToken = "Bearer " + accessToken 2836 | 2837 | // Test tool invoke endpoint 2838 | invokeTcs := []struct { 2839 | name string 2840 | api string 2841 | requestHeader map[string]string 2842 | requestBody io.Reader 2843 | wantKey string 2844 | isErr bool 2845 | }{ 2846 | { 2847 | name: "invoke my-search-catalog-tool without body", 2848 | api: "http://127.0.0.1:5000/api/tool/my-search-catalog-tool/invoke", 2849 | requestHeader: map[string]string{}, 2850 | requestBody: bytes.NewBuffer([]byte(`{}`)), 2851 | isErr: true, 2852 | }, 2853 | { 2854 | name: "invoke my-search-catalog-tool", 2855 | api: "http://127.0.0.1:5000/api/tool/my-search-catalog-tool/invoke", 2856 | requestHeader: map[string]string{}, 2857 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"prompt\":\"%s\", \"types\":[\"TABLE\"], \"datasetIds\":[\"%s\"]}", tableName, datasetName))), 2858 | wantKey: "DisplayName", 2859 | isErr: false, 2860 | }, 2861 | { 2862 | name: "Invoke my-auth-search-catalog-tool with auth token", 2863 | api: "http://127.0.0.1:5000/api/tool/my-auth-search-catalog-tool/invoke", 2864 | requestHeader: map[string]string{"my-google-auth_token": idToken}, 2865 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"prompt\":\"%s\", \"types\":[\"TABLE\"], \"datasetIds\":[\"%s\"]}", tableName, datasetName))), 2866 | wantKey: "DisplayName", 2867 | isErr: false, 2868 | }, 2869 | { 2870 | name: "Invoke my-auth-search-catalog-tool with correct project", 2871 | api: "http://127.0.0.1:5000/api/tool/my-auth-search-catalog-tool/invoke", 2872 | requestHeader: map[string]string{"my-google-auth_token": idToken}, 2873 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"prompt\":\"%s\", \"types\":[\"TABLE\"], \"projectIds\":[\"%s\"], \"datasetIds\":[\"%s\"]}", tableName, BigqueryProject, datasetName))), 2874 | wantKey: "DisplayName", 2875 | isErr: false, 2876 | }, 2877 | { 2878 | name: "Invoke my-auth-search-catalog-tool with non-existent project", 2879 | api: "http://127.0.0.1:5000/api/tool/my-auth-search-catalog-tool/invoke", 2880 | requestHeader: map[string]string{"my-google-auth_token": idToken}, 2881 | requestBody: 
bytes.NewBuffer([]byte(fmt.Sprintf("{\"prompt\":\"%s\", \"types\":[\"TABLE\"], \"projectIds\":[\"%s-%s\"], \"datasetIds\":[\"%s\"]}", tableName, BigqueryProject, uuid.NewString(), datasetName))), 2882 | isErr: true, 2883 | }, 2884 | { 2885 | name: "Invoke my-auth-search-catalog-tool with invalid auth token", 2886 | api: "http://127.0.0.1:5000/api/tool/my-auth-search-catalog-tool/invoke", 2887 | requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"}, 2888 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"prompt\":\"%s\", \"types\":[\"TABLE\"], \"datasetIds\":[\"%s\"]}", tableName, datasetName))), 2889 | isErr: true, 2890 | }, 2891 | { 2892 | name: "Invoke my-auth-search-catalog-tool without auth token", 2893 | api: "http://127.0.0.1:5000/api/tool/my-auth-search-catalog-tool/invoke", 2894 | requestHeader: map[string]string{}, 2895 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"prompt\":\"%s\", \"types\":[\"TABLE\"], \"datasetIds\":[\"%s\"]}", tableName, datasetName))), 2896 | isErr: true, 2897 | }, 2898 | { 2899 | name: "Invoke my-client-auth-search-catalog-tool without auth token", 2900 | api: "http://127.0.0.1:5000/api/tool/my-client-auth-search-catalog-tool/invoke", 2901 | requestHeader: map[string]string{}, 2902 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"prompt\":\"%s\", \"types\":[\"TABLE\"], \"datasetIds\":[\"%s\"]}", tableName, datasetName))), 2903 | isErr: true, 2904 | }, 2905 | { 2906 | name: "Invoke my-client-auth-search-catalog-tool with auth token", 2907 | api: "http://127.0.0.1:5000/api/tool/my-client-auth-search-catalog-tool/invoke", 2908 | requestHeader: map[string]string{"Authorization": accessToken}, 2909 | requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"prompt\":\"%s\", \"types\":[\"TABLE\"], \"datasetIds\":[\"%s\"]}", tableName, datasetName))), 2910 | wantKey: "DisplayName", 2911 | isErr: false, 2912 | }, 2913 | } 2914 | for _, tc := range invokeTcs { 2915 | t.Run(tc.name, func(t *testing.T) { 2916 | // Send Tool invocation request 2917 | req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody) 2918 | if err != nil { 2919 | t.Fatalf("unable to create request: %s", err) 2920 | } 2921 | req.Header.Add("Content-type", "application/json") 2922 | for k, v := range tc.requestHeader { 2923 | req.Header.Add(k, v) 2924 | } 2925 | resp, err := http.DefaultClient.Do(req) 2926 | if err != nil { 2927 | t.Fatalf("unable to send request: %s", err) 2928 | } 2929 | defer resp.Body.Close() 2930 | 2931 | if resp.StatusCode != http.StatusOK { 2932 | if tc.isErr { 2933 | return 2934 | } 2935 | bodyBytes, _ := io.ReadAll(resp.Body) 2936 | t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) 2937 | } 2938 | 2939 | var result map[string]interface{} 2940 | if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { 2941 | t.Fatalf("error parsing response body: %s", err) 2942 | } 2943 | resultStr, ok := result["result"].(string) 2944 | if !ok { 2945 | if result["result"] == nil && tc.isErr { 2946 | return 2947 | } 2948 | t.Fatalf("expected 'result' field to be a string, got %T", result["result"]) 2949 | } 2950 | if tc.isErr && (resultStr == "" || resultStr == "[]") { 2951 | return 2952 | } 2953 | var entries []interface{} 2954 | if err := json.Unmarshal([]byte(resultStr), &entries); err != nil { 2955 | t.Fatalf("error unmarshalling result string: %v", err) 2956 | } 2957 | 2958 | if !tc.isErr { 2959 | if len(entries) != 1 { 2960 | t.Fatalf("expected exactly one entry, but got %d", len(entries)) 
2961 | } 2962 | entry, ok := entries[0].(map[string]interface{}) 2963 | if !ok { 2964 | t.Fatalf("expected first entry to be a map, got %T", entries[0]) 2965 | } 2966 | respTable, ok := entry[tc.wantKey] 2967 | if !ok { 2968 | t.Fatalf("expected entry to have key '%s', but it was not found in %v", tc.wantKey, entry) 2969 | } 2970 | if respTable != tableName { 2971 | t.Fatalf("expected key '%s' to have value '%s', but got %s", tc.wantKey, tableName, respTable) 2972 | } 2973 | } else { 2974 | if len(entries) != 0 { 2975 | t.Fatalf("expected 0 entries, but got %d", len(entries)) 2976 | } 2977 | } 2978 | }) 2979 | } 2980 | } 2981 | 2982 | func runForecastWithRestriction(t *testing.T, allowedTableFullName, disallowedTableFullName string) { 2983 | allowedTableUnquoted := strings.ReplaceAll(allowedTableFullName, "`", "") 2984 | disallowedTableUnquoted := strings.ReplaceAll(disallowedTableFullName, "`", "") 2985 | disallowedDatasetFQN := strings.Join(strings.Split(disallowedTableUnquoted, ".")[0:2], ".") 2986 | 2987 | testCases := []struct { 2988 | name string 2989 | historyData string 2990 | wantStatusCode int 2991 | wantInResult string 2992 | wantInError string 2993 | }{ 2994 | { 2995 | name: "invoke with allowed table name", 2996 | historyData: allowedTableUnquoted, 2997 | wantStatusCode: http.StatusOK, 2998 | wantInResult: `"forecast_timestamp"`, 2999 | }, 3000 | { 3001 | name: "invoke with disallowed table name", 3002 | historyData: disallowedTableUnquoted, 3003 | wantStatusCode: http.StatusBadRequest, 3004 | wantInError: fmt.Sprintf("access to dataset '%s' (from table '%s') is not allowed", disallowedDatasetFQN, disallowedTableUnquoted), 3005 | }, 3006 | { 3007 | name: "invoke with query on allowed table", 3008 | historyData: fmt.Sprintf("SELECT * FROM %s", allowedTableFullName), 3009 | wantStatusCode: http.StatusOK, 3010 | wantInResult: `"forecast_timestamp"`, 3011 | }, 3012 | { 3013 | name: "invoke with query on disallowed table", 3014 | historyData: fmt.Sprintf("SELECT * FROM %s", disallowedTableFullName), 3015 | wantStatusCode: http.StatusBadRequest, 3016 | wantInError: fmt.Sprintf("query in history_data accesses dataset '%s', which is not in the allowed list", disallowedDatasetFQN), 3017 | }, 3018 | } 3019 | 3020 | for _, tc := range testCases { 3021 | t.Run(tc.name, func(t *testing.T) { 3022 | requestBodyMap := map[string]any{ 3023 | "history_data": tc.historyData, 3024 | "timestamp_col": "ts", 3025 | "data_col": "data", 3026 | } 3027 | bodyBytes, err := json.Marshal(requestBodyMap) 3028 | if err != nil { 3029 | t.Fatalf("failed to marshal request body: %v", err) 3030 | } 3031 | body := bytes.NewBuffer(bodyBytes) 3032 | 3033 | req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:5000/api/tool/forecast-restricted/invoke", body) 3034 | if err != nil { 3035 | t.Fatalf("unable to create request: %s", err) 3036 | } 3037 | req.Header.Add("Content-type", "application/json") 3038 | resp, err := http.DefaultClient.Do(req) 3039 | if err != nil { 3040 | t.Fatalf("unable to send request: %s", err) 3041 | } 3042 | defer resp.Body.Close() 3043 | 3044 | if resp.StatusCode != tc.wantStatusCode { 3045 | bodyBytes, _ := io.ReadAll(resp.Body) 3046 | t.Fatalf("unexpected status code: got %d, want %d. 
Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes)) 3047 | } 3048 | 3049 | if tc.wantInResult != "" { 3050 | var respBody map[string]interface{} 3051 | if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { 3052 | t.Fatalf("error parsing response body: %v", err) 3053 | } 3054 | got, ok := respBody["result"].(string) 3055 | if !ok { 3056 | t.Fatalf("unable to find result in response body") 3057 | } 3058 | if !strings.Contains(got, tc.wantInResult) { 3059 | t.Errorf("unexpected result: got %q, want to contain %q", got, tc.wantInResult) 3060 | } 3061 | } 3062 | 3063 | if tc.wantInError != "" { 3064 | bodyBytes, _ := io.ReadAll(resp.Body) 3065 | if !strings.Contains(string(bodyBytes), tc.wantInError) { 3066 | t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError) 3067 | } 3068 | } 3069 | }) 3070 | } 3071 | } 3072 | ```