This is page 20 of 35. Use http://codebase.md/googleapis/genai-toolbox?lines=false&page={x} to view the full context. # Directory Structure ``` ├── .ci │ ├── continuous.release.cloudbuild.yaml │ ├── generate_release_table.sh │ ├── integration.cloudbuild.yaml │ ├── quickstart_test │ │ ├── go.integration.cloudbuild.yaml │ │ ├── js.integration.cloudbuild.yaml │ │ ├── py.integration.cloudbuild.yaml │ │ ├── run_go_tests.sh │ │ ├── run_js_tests.sh │ │ ├── run_py_tests.sh │ │ └── setup_hotels_sample.sql │ ├── test_with_coverage.sh │ └── versioned.release.cloudbuild.yaml ├── .github │ ├── auto-label.yaml │ ├── blunderbuss.yml │ ├── CODEOWNERS │ ├── header-checker-lint.yml │ ├── ISSUE_TEMPLATE │ │ ├── bug_report.yml │ │ ├── config.yml │ │ ├── feature_request.yml │ │ └── question.yml │ ├── label-sync.yml │ ├── labels.yaml │ ├── PULL_REQUEST_TEMPLATE.md │ ├── release-please.yml │ ├── renovate.json5 │ ├── sync-repo-settings.yaml │ └── workflows │ ├── cloud_build_failure_reporter.yml │ ├── deploy_dev_docs.yaml │ ├── deploy_previous_version_docs.yaml │ ├── deploy_versioned_docs.yaml │ ├── docs_deploy.yaml │ ├── docs_preview_clean.yaml │ ├── docs_preview_deploy.yaml │ ├── lint.yaml │ ├── schedule_reporter.yml │ ├── sync-labels.yaml │ └── tests.yaml ├── .gitignore ├── .gitmodules ├── .golangci.yaml ├── .hugo │ ├── archetypes │ │ └── default.md │ ├── assets │ │ ├── icons │ │ │ └── logo.svg │ │ └── scss │ │ ├── _styles_project.scss │ │ └── _variables_project.scss │ ├── go.mod │ ├── go.sum │ ├── hugo.toml │ ├── layouts │ │ ├── _default │ │ │ └── home.releases.releases │ │ ├── index.llms-full.txt │ │ ├── index.llms.txt │ │ ├── partials │ │ │ ├── hooks │ │ │ │ └── head-end.html │ │ │ ├── navbar-version-selector.html │ │ │ ├── page-meta-links.html │ │ │ └── td │ │ │ └── render-heading.html │ │ ├── robot.txt │ │ └── shortcodes │ │ ├── include.html │ │ ├── ipynb.html │ │ └── regionInclude.html │ ├── package-lock.json │ ├── package.json │ └── static │ ├── favicons │ │ ├── 
android-chrome-192x192.png │ │ ├── android-chrome-512x512.png │ │ ├── apple-touch-icon.png │ │ ├── favicon-16x16.png │ │ ├── favicon-32x32.png │ │ └── favicon.ico │ └── js │ └── w3.js ├── CHANGELOG.md ├── cmd │ ├── options_test.go │ ├── options.go │ ├── root_test.go │ ├── root.go │ └── version.txt ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── DEVELOPER.md ├── Dockerfile ├── docs │ └── en │ ├── _index.md │ ├── about │ │ ├── _index.md │ │ └── faq.md │ ├── concepts │ │ ├── _index.md │ │ └── telemetry │ │ ├── index.md │ │ ├── telemetry_flow.png │ │ └── telemetry_traces.png │ ├── getting-started │ │ ├── _index.md │ │ ├── colab_quickstart.ipynb │ │ ├── configure.md │ │ ├── introduction │ │ │ ├── _index.md │ │ │ └── architecture.png │ │ ├── local_quickstart_go.md │ │ ├── local_quickstart_js.md │ │ ├── local_quickstart.md │ │ ├── mcp_quickstart │ │ │ ├── _index.md │ │ │ ├── inspector_tools.png │ │ │ └── inspector.png │ │ └── quickstart │ │ ├── go │ │ │ ├── genAI │ │ │ │ ├── go.mod │ │ │ │ ├── go.sum │ │ │ │ └── quickstart.go │ │ │ ├── genkit │ │ │ │ ├── go.mod │ │ │ │ ├── go.sum │ │ │ │ └── quickstart.go │ │ │ ├── langchain │ │ │ │ ├── go.mod │ │ │ │ ├── go.sum │ │ │ │ └── quickstart.go │ │ │ ├── openAI │ │ │ │ ├── go.mod │ │ │ │ ├── go.sum │ │ │ │ └── quickstart.go │ │ │ └── quickstart_test.go │ │ ├── golden.txt │ │ ├── js │ │ │ ├── genAI │ │ │ │ ├── package-lock.json │ │ │ │ ├── package.json │ │ │ │ └── quickstart.js │ │ │ ├── genkit │ │ │ │ ├── package-lock.json │ │ │ │ ├── package.json │ │ │ │ └── quickstart.js │ │ │ ├── langchain │ │ │ │ ├── package-lock.json │ │ │ │ ├── package.json │ │ │ │ └── quickstart.js │ │ │ ├── llamaindex │ │ │ │ ├── package-lock.json │ │ │ │ ├── package.json │ │ │ │ └── quickstart.js │ │ │ └── quickstart.test.js │ │ ├── python │ │ │ ├── __init__.py │ │ │ ├── adk │ │ │ │ ├── quickstart.py │ │ │ │ └── requirements.txt │ │ │ ├── core │ │ │ │ ├── quickstart.py │ │ │ │ └── requirements.txt │ │ │ ├── langchain │ │ │ │ ├── quickstart.py │ │ │ │ └── 
requirements.txt │ │ │ ├── llamaindex │ │ │ │ ├── quickstart.py │ │ │ │ └── requirements.txt │ │ │ └── quickstart_test.py │ │ └── shared │ │ ├── cloud_setup.md │ │ ├── configure_toolbox.md │ │ └── database_setup.md │ ├── how-to │ │ ├── _index.md │ │ ├── connect_via_geminicli.md │ │ ├── connect_via_mcp.md │ │ ├── connect-ide │ │ │ ├── _index.md │ │ │ ├── alloydb_pg_admin_mcp.md │ │ │ ├── alloydb_pg_mcp.md │ │ │ ├── bigquery_mcp.md │ │ │ ├── cloud_sql_mssql_admin_mcp.md │ │ │ ├── cloud_sql_mssql_mcp.md │ │ │ ├── cloud_sql_mysql_admin_mcp.md │ │ │ ├── cloud_sql_mysql_mcp.md │ │ │ ├── cloud_sql_pg_admin_mcp.md │ │ │ ├── cloud_sql_pg_mcp.md │ │ │ ├── firestore_mcp.md │ │ │ ├── looker_mcp.md │ │ │ ├── mssql_mcp.md │ │ │ ├── mysql_mcp.md │ │ │ ├── neo4j_mcp.md │ │ │ ├── postgres_mcp.md │ │ │ ├── spanner_mcp.md │ │ │ └── sqlite_mcp.md │ │ ├── deploy_docker.md │ │ ├── deploy_gke.md │ │ ├── deploy_toolbox.md │ │ ├── export_telemetry.md │ │ └── toolbox-ui │ │ ├── edit-headers.gif │ │ ├── edit-headers.png │ │ ├── index.md │ │ ├── optional-param-checked.png │ │ ├── optional-param-unchecked.png │ │ ├── run-tool.gif │ │ ├── tools.png │ │ └── toolsets.png │ ├── reference │ │ ├── _index.md │ │ ├── cli.md │ │ └── prebuilt-tools.md │ ├── resources │ │ ├── _index.md │ │ ├── authServices │ │ │ ├── _index.md │ │ │ └── google.md │ │ ├── sources │ │ │ ├── _index.md │ │ │ ├── alloydb-admin.md │ │ │ ├── alloydb-pg.md │ │ │ ├── bigquery.md │ │ │ ├── bigtable.md │ │ │ ├── cassandra.md │ │ │ ├── clickhouse.md │ │ │ ├── cloud-monitoring.md │ │ │ ├── cloud-sql-admin.md │ │ │ ├── cloud-sql-mssql.md │ │ │ ├── cloud-sql-mysql.md │ │ │ ├── cloud-sql-pg.md │ │ │ ├── couchbase.md │ │ │ ├── dataplex.md │ │ │ ├── dgraph.md │ │ │ ├── firebird.md │ │ │ ├── firestore.md │ │ │ ├── http.md │ │ │ ├── looker.md │ │ │ ├── mongodb.md │ │ │ ├── mssql.md │ │ │ ├── mysql.md │ │ │ ├── neo4j.md │ │ │ ├── oceanbase.md │ │ │ ├── oracle.md │ │ │ ├── postgres.md │ │ │ ├── redis.md │ │ │ ├── spanner.md │ │ │ ├── sqlite.md 
│ │ │ ├── tidb.md │ │ │ ├── trino.md │ │ │ ├── valkey.md │ │ │ └── yugabytedb.md │ │ └── tools │ │ ├── _index.md │ │ ├── alloydb │ │ │ ├── _index.md │ │ │ ├── alloydb-create-cluster.md │ │ │ ├── alloydb-create-instance.md │ │ │ ├── alloydb-create-user.md │ │ │ ├── alloydb-get-cluster.md │ │ │ ├── alloydb-get-instance.md │ │ │ ├── alloydb-get-user.md │ │ │ ├── alloydb-list-clusters.md │ │ │ ├── alloydb-list-instances.md │ │ │ ├── alloydb-list-users.md │ │ │ └── alloydb-wait-for-operation.md │ │ ├── alloydbainl │ │ │ ├── _index.md │ │ │ └── alloydb-ai-nl.md │ │ ├── bigquery │ │ │ ├── _index.md │ │ │ ├── bigquery-analyze-contribution.md │ │ │ ├── bigquery-conversational-analytics.md │ │ │ ├── bigquery-execute-sql.md │ │ │ ├── bigquery-forecast.md │ │ │ ├── bigquery-get-dataset-info.md │ │ │ ├── bigquery-get-table-info.md │ │ │ ├── bigquery-list-dataset-ids.md │ │ │ ├── bigquery-list-table-ids.md │ │ │ ├── bigquery-search-catalog.md │ │ │ └── bigquery-sql.md │ │ ├── bigtable │ │ │ ├── _index.md │ │ │ └── bigtable-sql.md │ │ ├── cassandra │ │ │ ├── _index.md │ │ │ └── cassandra-cql.md │ │ ├── clickhouse │ │ │ ├── _index.md │ │ │ ├── clickhouse-execute-sql.md │ │ │ ├── clickhouse-list-databases.md │ │ │ ├── clickhouse-list-tables.md │ │ │ └── clickhouse-sql.md │ │ ├── cloudmonitoring │ │ │ ├── _index.md │ │ │ └── cloud-monitoring-query-prometheus.md │ │ ├── cloudsql │ │ │ ├── _index.md │ │ │ ├── cloudsqlcreatedatabase.md │ │ │ ├── cloudsqlcreateusers.md │ │ │ ├── cloudsqlgetinstances.md │ │ │ ├── cloudsqllistdatabases.md │ │ │ ├── cloudsqllistinstances.md │ │ │ ├── cloudsqlmssqlcreateinstance.md │ │ │ ├── cloudsqlmysqlcreateinstance.md │ │ │ ├── cloudsqlpgcreateinstances.md │ │ │ └── cloudsqlwaitforoperation.md │ │ ├── couchbase │ │ │ ├── _index.md │ │ │ └── couchbase-sql.md │ │ ├── dataform │ │ │ ├── _index.md │ │ │ └── dataform-compile-local.md │ │ ├── dataplex │ │ │ ├── _index.md │ │ │ ├── dataplex-lookup-entry.md │ │ │ ├── dataplex-search-aspect-types.md │ │ │ └── 
dataplex-search-entries.md │ │ ├── dgraph │ │ │ ├── _index.md │ │ │ └── dgraph-dql.md │ │ ├── firebird │ │ │ ├── _index.md │ │ │ ├── firebird-execute-sql.md │ │ │ └── firebird-sql.md │ │ ├── firestore │ │ │ ├── _index.md │ │ │ ├── firestore-add-documents.md │ │ │ ├── firestore-delete-documents.md │ │ │ ├── firestore-get-documents.md │ │ │ ├── firestore-get-rules.md │ │ │ ├── firestore-list-collections.md │ │ │ ├── firestore-query-collection.md │ │ │ ├── firestore-query.md │ │ │ ├── firestore-update-document.md │ │ │ └── firestore-validate-rules.md │ │ ├── http │ │ │ ├── _index.md │ │ │ └── http.md │ │ ├── looker │ │ │ ├── _index.md │ │ │ ├── looker-add-dashboard-element.md │ │ │ ├── looker-conversational-analytics.md │ │ │ ├── looker-create-project-file.md │ │ │ ├── looker-delete-project-file.md │ │ │ ├── looker-dev-mode.md │ │ │ ├── looker-get-dashboards.md │ │ │ ├── looker-get-dimensions.md │ │ │ ├── looker-get-explores.md │ │ │ ├── looker-get-filters.md │ │ │ ├── looker-get-looks.md │ │ │ ├── looker-get-measures.md │ │ │ ├── looker-get-models.md │ │ │ ├── looker-get-parameters.md │ │ │ ├── looker-get-project-file.md │ │ │ ├── looker-get-project-files.md │ │ │ ├── looker-get-projects.md │ │ │ ├── looker-health-analyze.md │ │ │ ├── looker-health-pulse.md │ │ │ ├── looker-health-vacuum.md │ │ │ ├── looker-make-dashboard.md │ │ │ ├── looker-make-look.md │ │ │ ├── looker-query-sql.md │ │ │ ├── looker-query-url.md │ │ │ ├── looker-query.md │ │ │ ├── looker-run-look.md │ │ │ └── looker-update-project-file.md │ │ ├── mongodb │ │ │ ├── _index.md │ │ │ ├── mongodb-aggregate.md │ │ │ ├── mongodb-delete-many.md │ │ │ ├── mongodb-delete-one.md │ │ │ ├── mongodb-find-one.md │ │ │ ├── mongodb-find.md │ │ │ ├── mongodb-insert-many.md │ │ │ ├── mongodb-insert-one.md │ │ │ ├── mongodb-update-many.md │ │ │ └── mongodb-update-one.md │ │ ├── mssql │ │ │ ├── _index.md │ │ │ ├── mssql-execute-sql.md │ │ │ ├── mssql-list-tables.md │ │ │ └── mssql-sql.md │ │ ├── mysql │ │ │ ├── 
_index.md │ │ │ ├── mysql-execute-sql.md │ │ │ ├── mysql-list-active-queries.md │ │ │ ├── mysql-list-table-fragmentation.md │ │ │ ├── mysql-list-tables-missing-unique-indexes.md │ │ │ ├── mysql-list-tables.md │ │ │ └── mysql-sql.md │ │ ├── neo4j │ │ │ ├── _index.md │ │ │ ├── neo4j-cypher.md │ │ │ ├── neo4j-execute-cypher.md │ │ │ └── neo4j-schema.md │ │ ├── oceanbase │ │ │ ├── _index.md │ │ │ ├── oceanbase-execute-sql.md │ │ │ └── oceanbase-sql.md │ │ ├── oracle │ │ │ ├── _index.md │ │ │ ├── oracle-execute-sql.md │ │ │ └── oracle-sql.md │ │ ├── postgres │ │ │ ├── _index.md │ │ │ ├── postgres-execute-sql.md │ │ │ ├── postgres-list-active-queries.md │ │ │ ├── postgres-list-available-extensions.md │ │ │ ├── postgres-list-installed-extensions.md │ │ │ ├── postgres-list-tables.md │ │ │ └── postgres-sql.md │ │ ├── redis │ │ │ ├── _index.md │ │ │ └── redis.md │ │ ├── spanner │ │ │ ├── _index.md │ │ │ ├── spanner-execute-sql.md │ │ │ ├── spanner-list-tables.md │ │ │ └── spanner-sql.md │ │ ├── sqlite │ │ │ ├── _index.md │ │ │ ├── sqlite-execute-sql.md │ │ │ └── sqlite-sql.md │ │ ├── tidb │ │ │ ├── _index.md │ │ │ ├── tidb-execute-sql.md │ │ │ └── tidb-sql.md │ │ ├── trino │ │ │ ├── _index.md │ │ │ ├── trino-execute-sql.md │ │ │ └── trino-sql.md │ │ ├── utility │ │ │ ├── _index.md │ │ │ └── wait.md │ │ ├── valkey │ │ │ ├── _index.md │ │ │ └── valkey.md │ │ └── yuagbytedb │ │ ├── _index.md │ │ └── yugabytedb-sql.md │ ├── samples │ │ ├── _index.md │ │ ├── alloydb │ │ │ ├── _index.md │ │ │ ├── ai-nl │ │ │ │ ├── alloydb_ai_nl.ipynb │ │ │ │ └── index.md │ │ │ └── mcp_quickstart.md │ │ ├── bigquery │ │ │ ├── _index.md │ │ │ ├── colab_quickstart_bigquery.ipynb │ │ │ ├── local_quickstart.md │ │ │ └── mcp_quickstart │ │ │ ├── _index.md │ │ │ ├── inspector_tools.png │ │ │ └── inspector.png │ │ └── looker │ │ ├── _index.md │ │ ├── looker_gemini_oauth │ │ │ ├── _index.md │ │ │ ├── authenticated.png │ │ │ ├── authorize.png │ │ │ └── registration.png │ │ ├── looker_gemini.md │ │ └── 
looker_mcp_inspector │ │ ├── _index.md │ │ ├── inspector_tools.png │ │ └── inspector.png │ └── sdks │ ├── _index.md │ ├── go-sdk.md │ ├── js-sdk.md │ └── python-sdk.md ├── gemini-extension.json ├── go.mod ├── go.sum ├── internal │ ├── auth │ │ ├── auth.go │ │ └── google │ │ └── google.go │ ├── log │ │ ├── handler.go │ │ ├── log_test.go │ │ ├── log.go │ │ └── logger.go │ ├── prebuiltconfigs │ │ ├── prebuiltconfigs_test.go │ │ ├── prebuiltconfigs.go │ │ └── tools │ │ ├── alloydb-postgres-admin.yaml │ │ ├── alloydb-postgres-observability.yaml │ │ ├── alloydb-postgres.yaml │ │ ├── bigquery.yaml │ │ ├── clickhouse.yaml │ │ ├── cloud-sql-mssql-admin.yaml │ │ ├── cloud-sql-mssql-observability.yaml │ │ ├── cloud-sql-mssql.yaml │ │ ├── cloud-sql-mysql-admin.yaml │ │ ├── cloud-sql-mysql-observability.yaml │ │ ├── cloud-sql-mysql.yaml │ │ ├── cloud-sql-postgres-admin.yaml │ │ ├── cloud-sql-postgres-observability.yaml │ │ ├── cloud-sql-postgres.yaml │ │ ├── dataplex.yaml │ │ ├── firestore.yaml │ │ ├── looker-conversational-analytics.yaml │ │ ├── looker.yaml │ │ ├── mssql.yaml │ │ ├── mysql.yaml │ │ ├── neo4j.yaml │ │ ├── oceanbase.yaml │ │ ├── postgres.yaml │ │ ├── spanner-postgres.yaml │ │ ├── spanner.yaml │ │ └── sqlite.yaml │ ├── server │ │ ├── api_test.go │ │ ├── api.go │ │ ├── common_test.go │ │ ├── config.go │ │ ├── mcp │ │ │ ├── jsonrpc │ │ │ │ ├── jsonrpc_test.go │ │ │ │ └── jsonrpc.go │ │ │ ├── mcp.go │ │ │ ├── util │ │ │ │ └── lifecycle.go │ │ │ ├── v20241105 │ │ │ │ ├── method.go │ │ │ │ └── types.go │ │ │ ├── v20250326 │ │ │ │ ├── method.go │ │ │ │ └── types.go │ │ │ └── v20250618 │ │ │ ├── method.go │ │ │ └── types.go │ │ ├── mcp_test.go │ │ ├── mcp.go │ │ ├── server_test.go │ │ ├── server.go │ │ ├── static │ │ │ ├── assets │ │ │ │ └── mcptoolboxlogo.png │ │ │ ├── css │ │ │ │ └── style.css │ │ │ ├── index.html │ │ │ ├── js │ │ │ │ ├── auth.js │ │ │ │ ├── loadTools.js │ │ │ │ ├── mainContent.js │ │ │ │ ├── navbar.js │ │ │ │ ├── runTool.js │ │ │ │ ├── toolDisplay.js 
│ │ │ │ ├── tools.js │ │ │ │ └── toolsets.js │ │ │ ├── tools.html │ │ │ └── toolsets.html │ │ ├── web_test.go │ │ └── web.go │ ├── sources │ │ ├── alloydbadmin │ │ │ ├── alloydbadmin_test.go │ │ │ └── alloydbadmin.go │ │ ├── alloydbpg │ │ │ ├── alloydb_pg_test.go │ │ │ └── alloydb_pg.go │ │ ├── bigquery │ │ │ ├── bigquery_test.go │ │ │ └── bigquery.go │ │ ├── bigtable │ │ │ ├── bigtable_test.go │ │ │ └── bigtable.go │ │ ├── cassandra │ │ │ ├── cassandra_test.go │ │ │ └── cassandra.go │ │ ├── clickhouse │ │ │ ├── clickhouse_test.go │ │ │ └── clickhouse.go │ │ ├── cloudmonitoring │ │ │ ├── cloud_monitoring_test.go │ │ │ └── cloud_monitoring.go │ │ ├── cloudsqladmin │ │ │ ├── cloud_sql_admin_test.go │ │ │ └── cloud_sql_admin.go │ │ ├── cloudsqlmssql │ │ │ ├── cloud_sql_mssql_test.go │ │ │ └── cloud_sql_mssql.go │ │ ├── cloudsqlmysql │ │ │ ├── cloud_sql_mysql_test.go │ │ │ └── cloud_sql_mysql.go │ │ ├── cloudsqlpg │ │ │ ├── cloud_sql_pg_test.go │ │ │ └── cloud_sql_pg.go │ │ ├── couchbase │ │ │ ├── couchbase_test.go │ │ │ └── couchbase.go │ │ ├── dataplex │ │ │ ├── dataplex_test.go │ │ │ └── dataplex.go │ │ ├── dgraph │ │ │ ├── dgraph_test.go │ │ │ └── dgraph.go │ │ ├── dialect.go │ │ ├── firebird │ │ │ ├── firebird_test.go │ │ │ └── firebird.go │ │ ├── firestore │ │ │ ├── firestore_test.go │ │ │ └── firestore.go │ │ ├── http │ │ │ ├── http_test.go │ │ │ └── http.go │ │ ├── ip_type.go │ │ ├── looker │ │ │ ├── looker_test.go │ │ │ └── looker.go │ │ ├── mongodb │ │ │ ├── mongodb_test.go │ │ │ └── mongodb.go │ │ ├── mssql │ │ │ ├── mssql_test.go │ │ │ └── mssql.go │ │ ├── mysql │ │ │ ├── mysql_test.go │ │ │ └── mysql.go │ │ ├── neo4j │ │ │ ├── neo4j_test.go │ │ │ └── neo4j.go │ │ ├── oceanbase │ │ │ ├── oceanbase_test.go │ │ │ └── oceanbase.go │ │ ├── oracle │ │ │ └── oracle.go │ │ ├── postgres │ │ │ ├── postgres_test.go │ │ │ └── postgres.go │ │ ├── redis │ │ │ ├── redis_test.go │ │ │ └── redis.go │ │ ├── sources.go │ │ ├── spanner │ │ │ ├── spanner_test.go │ │ │ └── 
spanner.go │ │ ├── sqlite │ │ │ ├── sqlite_test.go │ │ │ └── sqlite.go │ │ ├── tidb │ │ │ ├── tidb_test.go │ │ │ └── tidb.go │ │ ├── trino │ │ │ ├── trino_test.go │ │ │ └── trino.go │ │ ├── util.go │ │ ├── valkey │ │ │ ├── valkey_test.go │ │ │ └── valkey.go │ │ └── yugabytedb │ │ ├── yugabytedb_test.go │ │ └── yugabytedb.go │ ├── telemetry │ │ ├── instrumentation.go │ │ └── telemetry.go │ ├── testutils │ │ └── testutils.go │ ├── tools │ │ ├── alloydb │ │ │ ├── alloydbcreatecluster │ │ │ │ ├── alloydbcreatecluster_test.go │ │ │ │ └── alloydbcreatecluster.go │ │ │ ├── alloydbcreateinstance │ │ │ │ ├── alloydbcreateinstance_test.go │ │ │ │ └── alloydbcreateinstance.go │ │ │ ├── alloydbcreateuser │ │ │ │ ├── alloydbcreateuser_test.go │ │ │ │ └── alloydbcreateuser.go │ │ │ ├── alloydbgetcluster │ │ │ │ ├── alloydbgetcluster_test.go │ │ │ │ └── alloydbgetcluster.go │ │ │ ├── alloydbgetinstance │ │ │ │ ├── alloydbgetinstance_test.go │ │ │ │ └── alloydbgetinstance.go │ │ │ ├── alloydbgetuser │ │ │ │ ├── alloydbgetuser_test.go │ │ │ │ └── alloydbgetuser.go │ │ │ ├── alloydblistclusters │ │ │ │ ├── alloydblistclusters_test.go │ │ │ │ └── alloydblistclusters.go │ │ │ ├── alloydblistinstances │ │ │ │ ├── alloydblistinstances_test.go │ │ │ │ └── alloydblistinstances.go │ │ │ ├── alloydblistusers │ │ │ │ ├── alloydblistusers_test.go │ │ │ │ └── alloydblistusers.go │ │ │ └── alloydbwaitforoperation │ │ │ ├── alloydbwaitforoperation_test.go │ │ │ └── alloydbwaitforoperation.go │ │ ├── alloydbainl │ │ │ ├── alloydbainl_test.go │ │ │ └── alloydbainl.go │ │ ├── bigquery │ │ │ ├── bigqueryanalyzecontribution │ │ │ │ ├── bigqueryanalyzecontribution_test.go │ │ │ │ └── bigqueryanalyzecontribution.go │ │ │ ├── bigquerycommon │ │ │ │ ├── table_name_parser_test.go │ │ │ │ ├── table_name_parser.go │ │ │ │ └── util.go │ │ │ ├── bigqueryconversationalanalytics │ │ │ │ ├── bigqueryconversationalanalytics_test.go │ │ │ │ └── bigqueryconversationalanalytics.go │ │ │ ├── bigqueryexecutesql │ │ │ 
│ ├── bigqueryexecutesql_test.go │ │ │ │ └── bigqueryexecutesql.go │ │ │ ├── bigqueryforecast │ │ │ │ ├── bigqueryforecast_test.go │ │ │ │ └── bigqueryforecast.go │ │ │ ├── bigquerygetdatasetinfo │ │ │ │ ├── bigquerygetdatasetinfo_test.go │ │ │ │ └── bigquerygetdatasetinfo.go │ │ │ ├── bigquerygettableinfo │ │ │ │ ├── bigquerygettableinfo_test.go │ │ │ │ └── bigquerygettableinfo.go │ │ │ ├── bigquerylistdatasetids │ │ │ │ ├── bigquerylistdatasetids_test.go │ │ │ │ └── bigquerylistdatasetids.go │ │ │ ├── bigquerylisttableids │ │ │ │ ├── bigquerylisttableids_test.go │ │ │ │ └── bigquerylisttableids.go │ │ │ ├── bigquerysearchcatalog │ │ │ │ ├── bigquerysearchcatalog_test.go │ │ │ │ └── bigquerysearchcatalog.go │ │ │ └── bigquerysql │ │ │ ├── bigquerysql_test.go │ │ │ └── bigquerysql.go │ │ ├── bigtable │ │ │ ├── bigtable_test.go │ │ │ └── bigtable.go │ │ ├── cassandra │ │ │ └── cassandracql │ │ │ ├── cassandracql_test.go │ │ │ └── cassandracql.go │ │ ├── clickhouse │ │ │ ├── clickhouseexecutesql │ │ │ │ ├── clickhouseexecutesql_test.go │ │ │ │ └── clickhouseexecutesql.go │ │ │ ├── clickhouselistdatabases │ │ │ │ ├── clickhouselistdatabases_test.go │ │ │ │ └── clickhouselistdatabases.go │ │ │ ├── clickhouselisttables │ │ │ │ ├── clickhouselisttables_test.go │ │ │ │ └── clickhouselisttables.go │ │ │ └── clickhousesql │ │ │ ├── clickhousesql_test.go │ │ │ └── clickhousesql.go │ │ ├── cloudmonitoring │ │ │ ├── cloudmonitoring_test.go │ │ │ └── cloudmonitoring.go │ │ ├── cloudsql │ │ │ ├── cloudsqlcreatedatabase │ │ │ │ ├── cloudsqlcreatedatabase_test.go │ │ │ │ └── cloudsqlcreatedatabase.go │ │ │ ├── cloudsqlcreateusers │ │ │ │ ├── cloudsqlcreateusers_test.go │ │ │ │ └── cloudsqlcreateusers.go │ │ │ ├── cloudsqlgetinstances │ │ │ │ ├── cloudsqlgetinstances_test.go │ │ │ │ └── cloudsqlgetinstances.go │ │ │ ├── cloudsqllistdatabases │ │ │ │ ├── cloudsqllistdatabases_test.go │ │ │ │ └── cloudsqllistdatabases.go │ │ │ ├── cloudsqllistinstances │ │ │ │ ├── 
cloudsqllistinstances_test.go │ │ │ │ └── cloudsqllistinstances.go │ │ │ └── cloudsqlwaitforoperation │ │ │ ├── cloudsqlwaitforoperation_test.go │ │ │ └── cloudsqlwaitforoperation.go │ │ ├── cloudsqlmssql │ │ │ └── cloudsqlmssqlcreateinstance │ │ │ ├── cloudsqlmssqlcreateinstance_test.go │ │ │ └── cloudsqlmssqlcreateinstance.go │ │ ├── cloudsqlmysql │ │ │ └── cloudsqlmysqlcreateinstance │ │ │ ├── cloudsqlmysqlcreateinstance_test.go │ │ │ └── cloudsqlmysqlcreateinstance.go │ │ ├── cloudsqlpg │ │ │ └── cloudsqlpgcreateinstances │ │ │ ├── cloudsqlpgcreateinstances_test.go │ │ │ └── cloudsqlpgcreateinstances.go │ │ ├── common_test.go │ │ ├── common.go │ │ ├── couchbase │ │ │ ├── couchbase_test.go │ │ │ └── couchbase.go │ │ ├── dataform │ │ │ └── dataformcompilelocal │ │ │ ├── dataformcompilelocal_test.go │ │ │ └── dataformcompilelocal.go │ │ ├── dataplex │ │ │ ├── dataplexlookupentry │ │ │ │ ├── dataplexlookupentry_test.go │ │ │ │ └── dataplexlookupentry.go │ │ │ ├── dataplexsearchaspecttypes │ │ │ │ ├── dataplexsearchaspecttypes_test.go │ │ │ │ └── dataplexsearchaspecttypes.go │ │ │ └── dataplexsearchentries │ │ │ ├── dataplexsearchentries_test.go │ │ │ └── dataplexsearchentries.go │ │ ├── dgraph │ │ │ ├── dgraph_test.go │ │ │ └── dgraph.go │ │ ├── firebird │ │ │ ├── firebirdexecutesql │ │ │ │ ├── firebirdexecutesql_test.go │ │ │ │ └── firebirdexecutesql.go │ │ │ └── firebirdsql │ │ │ ├── firebirdsql_test.go │ │ │ └── firebirdsql.go │ │ ├── firestore │ │ │ ├── firestoreadddocuments │ │ │ │ ├── firestoreadddocuments_test.go │ │ │ │ └── firestoreadddocuments.go │ │ │ ├── firestoredeletedocuments │ │ │ │ ├── firestoredeletedocuments_test.go │ │ │ │ └── firestoredeletedocuments.go │ │ │ ├── firestoregetdocuments │ │ │ │ ├── firestoregetdocuments_test.go │ │ │ │ └── firestoregetdocuments.go │ │ │ ├── firestoregetrules │ │ │ │ ├── firestoregetrules_test.go │ │ │ │ └── firestoregetrules.go │ │ │ ├── firestorelistcollections │ │ │ │ ├── firestorelistcollections_test.go │ │ │ 
│ └── firestorelistcollections.go │ │ │ ├── firestorequery │ │ │ │ ├── firestorequery_test.go │ │ │ │ └── firestorequery.go │ │ │ ├── firestorequerycollection │ │ │ │ ├── firestorequerycollection_test.go │ │ │ │ └── firestorequerycollection.go │ │ │ ├── firestoreupdatedocument │ │ │ │ ├── firestoreupdatedocument_test.go │ │ │ │ └── firestoreupdatedocument.go │ │ │ ├── firestorevalidaterules │ │ │ │ ├── firestorevalidaterules_test.go │ │ │ │ └── firestorevalidaterules.go │ │ │ └── util │ │ │ ├── converter_test.go │ │ │ ├── converter.go │ │ │ ├── validator_test.go │ │ │ └── validator.go │ │ ├── http │ │ │ ├── http_test.go │ │ │ └── http.go │ │ ├── http_method.go │ │ ├── looker │ │ │ ├── lookeradddashboardelement │ │ │ │ ├── lookeradddashboardelement_test.go │ │ │ │ └── lookeradddashboardelement.go │ │ │ ├── lookercommon │ │ │ │ ├── lookercommon_test.go │ │ │ │ └── lookercommon.go │ │ │ ├── lookerconversationalanalytics │ │ │ │ ├── lookerconversationalanalytics_test.go │ │ │ │ └── lookerconversationalanalytics.go │ │ │ ├── lookercreateprojectfile │ │ │ │ ├── lookercreateprojectfile_test.go │ │ │ │ └── lookercreateprojectfile.go │ │ │ ├── lookerdeleteprojectfile │ │ │ │ ├── lookerdeleteprojectfile_test.go │ │ │ │ └── lookerdeleteprojectfile.go │ │ │ ├── lookerdevmode │ │ │ │ ├── lookerdevmode_test.go │ │ │ │ └── lookerdevmode.go │ │ │ ├── lookergetdashboards │ │ │ │ ├── lookergetdashboards_test.go │ │ │ │ └── lookergetdashboards.go │ │ │ ├── lookergetdimensions │ │ │ │ ├── lookergetdimensions_test.go │ │ │ │ └── lookergetdimensions.go │ │ │ ├── lookergetexplores │ │ │ │ ├── lookergetexplores_test.go │ │ │ │ └── lookergetexplores.go │ │ │ ├── lookergetfilters │ │ │ │ ├── lookergetfilters_test.go │ │ │ │ └── lookergetfilters.go │ │ │ ├── lookergetlooks │ │ │ │ ├── lookergetlooks_test.go │ │ │ │ └── lookergetlooks.go │ │ │ ├── lookergetmeasures │ │ │ │ ├── lookergetmeasures_test.go │ │ │ │ └── lookergetmeasures.go │ │ │ ├── lookergetmodels │ │ │ │ ├── 
lookergetmodels_test.go │ │ │ │ └── lookergetmodels.go │ │ │ ├── lookergetparameters │ │ │ │ ├── lookergetparameters_test.go │ │ │ │ └── lookergetparameters.go │ │ │ ├── lookergetprojectfile │ │ │ │ ├── lookergetprojectfile_test.go │ │ │ │ └── lookergetprojectfile.go │ │ │ ├── lookergetprojectfiles │ │ │ │ ├── lookergetprojectfiles_test.go │ │ │ │ └── lookergetprojectfiles.go │ │ │ ├── lookergetprojects │ │ │ │ ├── lookergetprojects_test.go │ │ │ │ └── lookergetprojects.go │ │ │ ├── lookerhealthanalyze │ │ │ │ ├── lookerhealthanalyze_test.go │ │ │ │ └── lookerhealthanalyze.go │ │ │ ├── lookerhealthpulse │ │ │ │ ├── lookerhealthpulse_test.go │ │ │ │ └── lookerhealthpulse.go │ │ │ ├── lookerhealthvacuum │ │ │ │ ├── lookerhealthvacuum_test.go │ │ │ │ └── lookerhealthvacuum.go │ │ │ ├── lookermakedashboard │ │ │ │ ├── lookermakedashboard_test.go │ │ │ │ └── lookermakedashboard.go │ │ │ ├── lookermakelook │ │ │ │ ├── lookermakelook_test.go │ │ │ │ └── lookermakelook.go │ │ │ ├── lookerquery │ │ │ │ ├── lookerquery_test.go │ │ │ │ └── lookerquery.go │ │ │ ├── lookerquerysql │ │ │ │ ├── lookerquerysql_test.go │ │ │ │ └── lookerquerysql.go │ │ │ ├── lookerqueryurl │ │ │ │ ├── lookerqueryurl_test.go │ │ │ │ └── lookerqueryurl.go │ │ │ ├── lookerrunlook │ │ │ │ ├── lookerrunlook_test.go │ │ │ │ └── lookerrunlook.go │ │ │ └── lookerupdateprojectfile │ │ │ ├── lookerupdateprojectfile_test.go │ │ │ └── lookerupdateprojectfile.go │ │ ├── mongodb │ │ │ ├── mongodbaggregate │ │ │ │ ├── mongodbaggregate_test.go │ │ │ │ └── mongodbaggregate.go │ │ │ ├── mongodbdeletemany │ │ │ │ ├── mongodbdeletemany_test.go │ │ │ │ └── mongodbdeletemany.go │ │ │ ├── mongodbdeleteone │ │ │ │ ├── mongodbdeleteone_test.go │ │ │ │ └── mongodbdeleteone.go │ │ │ ├── mongodbfind │ │ │ │ ├── mongodbfind_test.go │ │ │ │ └── mongodbfind.go │ │ │ ├── mongodbfindone │ │ │ │ ├── mongodbfindone_test.go │ │ │ │ └── mongodbfindone.go │ │ │ ├── mongodbinsertmany │ │ │ │ ├── mongodbinsertmany_test.go │ │ │ │ └── 
mongodbinsertmany.go │ │ │ ├── mongodbinsertone │ │ │ │ ├── mongodbinsertone_test.go │ │ │ │ └── mongodbinsertone.go │ │ │ ├── mongodbupdatemany │ │ │ │ ├── mongodbupdatemany_test.go │ │ │ │ └── mongodbupdatemany.go │ │ │ └── mongodbupdateone │ │ │ ├── mongodbupdateone_test.go │ │ │ └── mongodbupdateone.go │ │ ├── mssql │ │ │ ├── mssqlexecutesql │ │ │ │ ├── mssqlexecutesql_test.go │ │ │ │ └── mssqlexecutesql.go │ │ │ ├── mssqllisttables │ │ │ │ ├── mssqllisttables_test.go │ │ │ │ └── mssqllisttables.go │ │ │ └── mssqlsql │ │ │ ├── mssqlsql_test.go │ │ │ └── mssqlsql.go │ │ ├── mysql │ │ │ ├── mysqlcommon │ │ │ │ └── mysqlcommon.go │ │ │ ├── mysqlexecutesql │ │ │ │ ├── mysqlexecutesql_test.go │ │ │ │ └── mysqlexecutesql.go │ │ │ ├── mysqllistactivequeries │ │ │ │ ├── mysqllistactivequeries_test.go │ │ │ │ └── mysqllistactivequeries.go │ │ │ ├── mysqllisttablefragmentation │ │ │ │ ├── mysqllisttablefragmentation_test.go │ │ │ │ └── mysqllisttablefragmentation.go │ │ │ ├── mysqllisttables │ │ │ │ ├── mysqllisttables_test.go │ │ │ │ └── mysqllisttables.go │ │ │ ├── mysqllisttablesmissinguniqueindexes │ │ │ │ ├── mysqllisttablesmissinguniqueindexes_test.go │ │ │ │ └── mysqllisttablesmissinguniqueindexes.go │ │ │ └── mysqlsql │ │ │ ├── mysqlsql_test.go │ │ │ └── mysqlsql.go │ │ ├── neo4j │ │ │ ├── neo4jcypher │ │ │ │ ├── neo4jcypher_test.go │ │ │ │ └── neo4jcypher.go │ │ │ ├── neo4jexecutecypher │ │ │ │ ├── classifier │ │ │ │ │ ├── classifier_test.go │ │ │ │ │ └── classifier.go │ │ │ │ ├── neo4jexecutecypher_test.go │ │ │ │ └── neo4jexecutecypher.go │ │ │ └── neo4jschema │ │ │ ├── cache │ │ │ │ ├── cache_test.go │ │ │ │ └── cache.go │ │ │ ├── helpers │ │ │ │ ├── helpers_test.go │ │ │ │ └── helpers.go │ │ │ ├── neo4jschema_test.go │ │ │ ├── neo4jschema.go │ │ │ └── types │ │ │ └── types.go │ │ ├── oceanbase │ │ │ ├── oceanbaseexecutesql │ │ │ │ ├── oceanbaseexecutesql_test.go │ │ │ │ └── oceanbaseexecutesql.go │ │ │ └── oceanbasesql │ │ │ ├── oceanbasesql_test.go │ │ │ 
└── oceanbasesql.go │ │ ├── oracle │ │ │ ├── oracleexecutesql │ │ │ │ └── oracleexecutesql.go │ │ │ └── oraclesql │ │ │ └── oraclesql.go │ │ ├── parameters_test.go │ │ ├── parameters.go │ │ ├── postgres │ │ │ ├── postgresexecutesql │ │ │ │ ├── postgresexecutesql_test.go │ │ │ │ └── postgresexecutesql.go │ │ │ ├── postgreslistactivequeries │ │ │ │ ├── postgreslistactivequeries_test.go │ │ │ │ └── postgreslistactivequeries.go │ │ │ ├── postgreslistavailableextensions │ │ │ │ ├── postgreslistavailableextensions_test.go │ │ │ │ └── postgreslistavailableextensions.go │ │ │ ├── postgreslistinstalledextensions │ │ │ │ ├── postgreslistinstalledextensions_test.go │ │ │ │ └── postgreslistinstalledextensions.go │ │ │ ├── postgreslisttables │ │ │ │ ├── postgreslisttables_test.go │ │ │ │ └── postgreslisttables.go │ │ │ └── postgressql │ │ │ ├── postgressql_test.go │ │ │ └── postgressql.go │ │ ├── redis │ │ │ ├── redis_test.go │ │ │ └── redis.go │ │ ├── spanner │ │ │ ├── spannerexecutesql │ │ │ │ ├── spannerexecutesql_test.go │ │ │ │ └── spannerexecutesql.go │ │ │ ├── spannerlisttables │ │ │ │ ├── spannerlisttables_test.go │ │ │ │ └── spannerlisttables.go │ │ │ └── spannersql │ │ │ ├── spanner_test.go │ │ │ └── spannersql.go │ │ ├── sqlite │ │ │ ├── sqliteexecutesql │ │ │ │ ├── sqliteexecutesql_test.go │ │ │ │ └── sqliteexecutesql.go │ │ │ └── sqlitesql │ │ │ ├── sqlitesql_test.go │ │ │ └── sqlitesql.go │ │ ├── tidb │ │ │ ├── tidbexecutesql │ │ │ │ ├── tidbexecutesql_test.go │ │ │ │ └── tidbexecutesql.go │ │ │ └── tidbsql │ │ │ ├── tidbsql_test.go │ │ │ └── tidbsql.go │ │ ├── tools_test.go │ │ ├── tools.go │ │ ├── toolsets.go │ │ ├── trino │ │ │ ├── trinoexecutesql │ │ │ │ ├── trinoexecutesql_test.go │ │ │ │ └── trinoexecutesql.go │ │ │ └── trinosql │ │ │ ├── trinosql_test.go │ │ │ └── trinosql.go │ │ ├── utility │ │ │ └── wait │ │ │ ├── wait_test.go │ │ │ └── wait.go │ │ ├── valkey │ │ │ ├── valkey_test.go │ │ │ └── valkey.go │ │ └── yugabytedbsql │ │ ├── yugabytedbsql_test.go 
│ │ └── yugabytedbsql.go │ └── util │ └── util.go ├── LICENSE ├── logo.png ├── main.go ├── MCP-TOOLBOX-EXTENSION.md ├── README.md └── tests ├── alloydb │ ├── alloydb_integration_test.go │ └── alloydb_wait_for_operation_test.go ├── alloydbainl │ └── alloydb_ai_nl_integration_test.go ├── alloydbpg │ └── alloydb_pg_integration_test.go ├── auth.go ├── bigquery │ └── bigquery_integration_test.go ├── bigtable │ └── bigtable_integration_test.go ├── cassandra │ └── cassandra_integration_test.go ├── clickhouse │ └── clickhouse_integration_test.go ├── cloudmonitoring │ └── cloud_monitoring_integration_test.go ├── cloudsql │ ├── cloud_sql_create_database_test.go │ ├── cloud_sql_create_users_test.go │ ├── cloud_sql_get_instances_test.go │ ├── cloud_sql_list_databases_test.go │ ├── cloudsql_list_instances_test.go │ └── cloudsql_wait_for_operation_test.go ├── cloudsqlmssql │ ├── cloud_sql_mssql_create_instance_integration_test.go │ └── cloud_sql_mssql_integration_test.go ├── cloudsqlmysql │ ├── cloud_sql_mysql_create_instance_integration_test.go │ └── cloud_sql_mysql_integration_test.go ├── cloudsqlpg │ ├── cloud_sql_pg_create_instances_test.go │ └── cloud_sql_pg_integration_test.go ├── common.go ├── couchbase │ └── couchbase_integration_test.go ├── dataform │ └── dataform_integration_test.go ├── dataplex │ └── dataplex_integration_test.go ├── dgraph │ └── dgraph_integration_test.go ├── firebird │ └── firebird_integration_test.go ├── firestore │ └── firestore_integration_test.go ├── http │ └── http_integration_test.go ├── looker │ └── looker_integration_test.go ├── mongodb │ └── mongodb_integration_test.go ├── mssql │ └── mssql_integration_test.go ├── mysql │ └── mysql_integration_test.go ├── neo4j │ └── neo4j_integration_test.go ├── oceanbase │ └── oceanbase_integration_test.go ├── option.go ├── oracle │ └── oracle_integration_test.go ├── postgres │ └── postgres_integration_test.go ├── redis │ └── redis_test.go ├── server.go ├── source.go ├── spanner │ └── 
spanner_integration_test.go ├── sqlite │ └── sqlite_integration_test.go ├── tidb │ └── tidb_integration_test.go ├── tool.go ├── trino │ └── trino_integration_test.go ├── utility │ └── wait_integration_test.go ├── valkey │ └── valkey_test.go └── yugabytedb └── yugabytedb_integration_test.go ``` # Files -------------------------------------------------------------------------------- /tests/cassandra/cassandra_integration_test.go: -------------------------------------------------------------------------------- ```go // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package cassandra

import (
	"context"
	"fmt"
	"log"
	"os"
	"regexp"
	"strings"
	"testing"
	"time"

	"github.com/gocql/gocql"
	"github.com/google/uuid"
	"github.com/googleapis/genai-toolbox/internal/testutils"
	"github.com/googleapis/genai-toolbox/tests"
)

var (
	CassandraSourceKind = "cassandra"
	CassandraToolKind   = "cassandra-cql"
	Hosts               = os.Getenv("CASSANDRA_HOST") // comma-separated host list
	Keyspace            = "example_keyspace"
	Username            = os.Getenv("CASSANDRA_USER")
	Password            = os.Getenv("CASSANDRA_PASS")
)

// parseHosts splits the comma-separated Hosts value into trimmed, non-empty
// host names. Both the toolbox source config and the direct test session use
// it so they always agree on the host list.
func parseHosts() []string {
	var hosts []string
	for _, h := range strings.Split(Hosts, ",") {
		if trimmed := strings.TrimSpace(h); trimmed != "" {
			hosts = append(hosts, trimmed)
		}
	}
	return hosts
}

// getCassandraVars fails the test if any required environment variable is
// unset and otherwise returns the toolbox source configuration for Cassandra.
func getCassandraVars(t *testing.T) map[string]any {
	switch "" {
	case Hosts:
		t.Fatal("'CASSANDRA_HOST' not set")
	case Username:
		t.Fatal("'CASSANDRA_USER' not set")
	case Password:
		t.Fatal("'CASSANDRA_PASS' not set")
	}

	return map[string]any{
		"kind":     CassandraSourceKind,
		"hosts":    parseHosts(),
		"keyspace": Keyspace,
		"username": Username,
		"password": Password,
	}
}

// initCassandraSession opens a session to the hosts in CASSANDRA_HOST and
// ensures the test keyspace exists. The caller owns the returned session and
// must Close it.
func initCassandraSession() (*gocql.Session, error) {
	hosts := parseHosts()
	if len(hosts) == 0 {
		return nil, fmt.Errorf("no valid hosts found in CASSANDRA_HOST env var")
	}

	cluster := gocql.NewCluster(hosts...)
	cluster.Consistency = gocql.Quorum
	cluster.ProtoVersion = 4
	cluster.DisableInitialHostLookup = true
	cluster.ConnectTimeout = 10 * time.Second
	cluster.NumConns = 2
	cluster.Authenticator = gocql.PasswordAuthenticator{
		Username: Username,
		Password: Password,
	}
	cluster.RetryPolicy = &gocql.ExponentialBackoffRetryPolicy{
		NumRetries: 3,
		Min:        200 * time.Millisecond,
		Max:        2 * time.Second,
	}

	session, err := cluster.CreateSession()
	if err != nil {
		return nil, fmt.Errorf("failed to create session: %w", err)
	}

	err = session.Query(fmt.Sprintf(`
		CREATE KEYSPACE IF NOT EXISTS %s
		WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}
	`, Keyspace)).Exec()
	if err != nil {
		// Don't leak the open session when setup fails.
		session.Close()
		return nil, fmt.Errorf("failed to create keyspace: %w", err)
	}
	return session, nil
}

// initTable creates tableName in the test keyspace and seeds it with four
// users (ids 1-4). User 4 has a null name so null handling can be tested.
func initTable(tableName string, session *gocql.Session) error {
	err := session.Query(fmt.Sprintf(`
		CREATE TABLE IF NOT EXISTS %s.%s (
			id int PRIMARY KEY,
			name text,
			email text,
			age int,
			is_active boolean,
			created_at timestamp
		)
	`, Keyspace, tableName)).Exec()
	if err != nil {
		return fmt.Errorf("failed to create table: %w", err)
	}

	// Fixed timestamps keep the expected JSON output reproducible.
	fixedTime := time.Date(2025, time.July, 25, 12, 0, 0, 0, time.UTC)
	dayAgo := fixedTime.Add(-24 * time.Hour)
	twelveHoursAgo := fixedTime.Add(-12 * time.Hour)

	// NOTE(review): the "[email protected]" literals look like email-obfuscation
	// residue in this copy; confirm the intended addresses against upstream.
	rows := []struct {
		id        int
		name      any // nil is written as a CQL null
		email     string
		age       int
		isActive  bool
		createdAt time.Time
	}{
		{3, "Alice", tests.ServiceAccountEmail, 25, true, dayAgo},
		{2, "Alex", "[email protected]", 30, false, twelveHoursAgo},
		{1, "Sid", "[email protected]", 10, true, fixedTime},
		{4, nil, "[email protected]", 40, false, fixedTime},
	}

	insertStmt := fmt.Sprintf(`
		INSERT INTO %s.%s (id, name, email, age, is_active, created_at)
		VALUES (?, ?, ?, ?, ?, ?)`, Keyspace, tableName)
	for _, r := range rows {
		if err := session.Query(insertStmt, r.id, r.name, r.email, r.age, r.isActive, r.createdAt).Exec(); err != nil {
			return fmt.Errorf("failed to insert user %d: %w", r.id, err)
		}
	}
	return nil
}

// dropTable drops tableName from the test keyspace. It is best-effort cleanup,
// so failures are logged rather than returned.
func dropTable(session *gocql.Session, tableName string) {
	err := session.Query(fmt.Sprintf("drop table %s.%s", Keyspace, tableName)).Exec()
	if err != nil {
		log.Printf("Failed to drop table %s: %v", tableName, err)
	}
}

// TestCassandra runs the shared toolbox integration suites against a live
// Cassandra cluster configured through CASSANDRA_HOST/USER/PASS.
func TestCassandra(t *testing.T) {
	// Validate the environment first so a missing variable fails with a clear
	// message instead of a connection error.
	sourceConfig := getCassandraVars(t)

	session, err := initCassandraSession()
	if err != nil {
		t.Fatal(err)
	}
	defer session.Close()

	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()

	var args []string

	paramTableName := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
	tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
	tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")

	for _, name := range []string{paramTableName, tableNameAuth, tableNameTemplateParam} {
		if err := initTable(name, session); err != nil {
			t.Fatal(err)
		}
		// Deferred inside the loop on purpose: every table is dropped when the test returns.
		defer dropTable(session, name)
	}

	paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt := createParamToolInfo(paramTableName)
	_, _, authToolStmt := getCassandraAuthToolInfo(tableNameAuth)
	toolsFile := tests.GetToolsConfig(sourceConfig, CassandraToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)

	tmplSelectCombined, tmplSelectFilterCombined := getCassandraTmplToolInfo()
	tmpSelectAll := "SELECT * FROM {{.tableName}} where id = 1"
	toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CassandraToolKind, tmplSelectCombined, tmplSelectFilterCombined, tmpSelectAll)

	cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
	if err != nil {
		t.Fatalf("command initialization returned an error: %s", err)
	}
	defer cleanup()

	waitCtx, waitCancel := context.WithTimeout(ctx, 10*time.Second)
	defer waitCancel()
	out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
	if err != nil {
		t.Logf("toolbox command logs: \n%s", out)
		t.Fatalf("toolbox didn't start successfully: %s", err)
	}

	selectIdNameWant, selectIdNullWant, selectArrayParamWant, mcpMyFailToolWant, mcpSelect1Want, mcpMyToolIdWant := getCassandraWants()
	selectAllWant, selectIdWant, selectNameWant := getCassandraTmplWants()

	tests.RunToolGetTest(t)
	tests.RunToolInvokeTest(t, "",
		tests.DisableSelect1Test(),
		tests.DisableOptionalNullParamTest(),
		tests.WithMyToolId3NameAliceWant(selectIdNameWant),
		tests.WithMyToolById4Want(selectIdNullWant),
		tests.WithMyArrayToolWant(selectArrayParamWant),
		tests.DisableSelect1AuthTest())
	tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam,
		tests.DisableSelectFilterTest(),
		tests.WithSelectAllWant(selectAllWant),
		tests.DisableDdlTest(),
		tests.DisableInsertTest(),
		tests.WithTmplSelectId1Want(selectIdWant),
		tests.WithTmplSelectNameWant(selectNameWant))
	tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want,
		tests.WithMcpMyToolId3NameAliceWant(mcpMyToolIdWant),
		tests.DisableMcpSelect1AuthTest())
}

// createParamToolInfo returns the CQL statements for the parameterized tools:
// by id and name, by id only, by name only, and by id/name IN lists.
func createParamToolInfo(tableName string) (string, string, string, string) {
	toolStatement := fmt.Sprintf("SELECT id, name FROM %s WHERE id = ? AND name = ? ALLOW FILTERING;", tableName)
	idParamStatement := fmt.Sprintf("SELECT id,name FROM %s WHERE id = ?;", tableName)
	nameParamStatement := fmt.Sprintf("SELECT id, name FROM %s WHERE name = ? ALLOW FILTERING;", tableName)
	arrayToolStatement := fmt.Sprintf("SELECT id, name FROM %s WHERE id IN ? AND name IN ? ALLOW FILTERING;", tableName)
	return toolStatement, idParamStatement, nameParamStatement, arrayToolStatement
}

// getCassandraAuthToolInfo returns create, insert and select statements for
// the auth tool. TestCassandra only uses the select statement; the table
// itself is created by initTable.
func getCassandraAuthToolInfo(tableName string) (string, string, string) {
	createStatement := fmt.Sprintf("CREATE TABLE %s (id UUID PRIMARY KEY, name TEXT, email TEXT);", tableName)
	insertStatement := fmt.Sprintf("INSERT INTO %s (id, name, email) VALUES (uuid(), ?, ?), (uuid(), ?, ?);", tableName)
	toolStatement := fmt.Sprintf("SELECT name FROM %s WHERE email = ? ALLOW FILTERING;", tableName)
	return createStatement, insertStatement, toolStatement
}

// getCassandraTmplToolInfo returns the templated "select by id" and
// "select filtered by name" statements, in that order.
func getCassandraTmplToolInfo() (string, string) {
	tmplSelectCombined := "SELECT age, id, name FROM {{.tableName}} where id = ?;"
	tmplSelectFilterCombined := "SELECT id, name FROM {{.tableName}} WHERE name = ? ALLOW FILTERING;"
	return tmplSelectCombined, tmplSelectFilterCombined
}

// getCassandraWants returns the expected outputs for the shared invoke and MCP
// suites. The MCP select-1 case is disabled, so its want is the placeholder "nil".
func getCassandraWants() (string, string, string, string, string, string) {
	selectIdNameWant := "[{\"id\":3,\"name\":\"Alice\"}]"
	selectIdNullWant := "[{\"id\":4,\"name\":\"\"}]"
	selectArrayParamWant := "[{\"id\":1,\"name\":\"Sid\"},{\"id\":3,\"name\":\"Alice\"}]"
	mcpMyFailToolWant := "{\"jsonrpc\":\"2.0\",\"id\":\"invoke-fail-tool\",\"result\":{\"content\":[{\"type\":\"text\",\"text\":\"unable to parse rows: line 1:0 no viable alternative at input 'SELEC' ([SELEC]...)\"}],\"isError\":true}}"
	mcpMyToolIdWant := "{\"jsonrpc\":\"2.0\",\"id\":\"my-tool\",\"result\":{\"content\":[{\"type\":\"text\",\"text\":\"[{\\\"id\\\":3,\\\"name\\\":\\\"Alice\\\"}]\"}]}}"
	mcpSelect1Want := "nil"
	return selectIdNameWant, selectIdNullWant, selectArrayParamWant, mcpMyFailToolWant, mcpSelect1Want, mcpMyToolIdWant
}

// getCassandraTmplWants returns the expected outputs for the template-parameter
// suite: select-all, select-by-id and select-by-name.
func getCassandraTmplWants() (string, string, string) {
	selectAllWant := "[{\"age\":10,\"created_at\":\"2025-07-25T12:00:00Z\",\"email\":\"[email protected]\",\"id\":1,\"is_active\":true,\"name\":\"Sid\"}]"
	selectIdWant := "[{\"age\":10,\"id\":1,\"name\":\"Sid\"}]"
	selectNameWant := "[{\"id\":2,\"name\":\"Alex\"}]"
	return selectAllWant, selectIdWant, selectNameWant
}
```
--------------------------------------------------------------------------------
/tests/couchbase/couchbase_integration_test.go:
--------------------------------------------------------------------------------
```go
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package couchbase

import (
	"context"
	"fmt"
	"os"
	"regexp"
	"strings"
	"testing"
	"time"

	"github.com/couchbase/gocb/v2"
	"github.com/google/uuid"
	"github.com/googleapis/genai-toolbox/internal/testutils"
	"github.com/googleapis/genai-toolbox/tests"
)

const (
	couchbaseSourceKind = "couchbase"
	couchbaseToolKind   = "couchbase-sql"
)

var (
	couchbaseConnection = os.Getenv("COUCHBASE_CONNECTION")
	couchbaseBucket     = os.Getenv("COUCHBASE_BUCKET")
	couchbaseScope      = os.Getenv("COUCHBASE_SCOPE")
	couchbaseUser       = os.Getenv("COUCHBASE_USER")
	couchbasePass       = os.Getenv("COUCHBASE_PASS")
)

// getCouchbaseVars fails the test if any required environment variable is
// unset and otherwise returns the toolbox source configuration for Couchbase.
func getCouchbaseVars(t *testing.T) map[string]any {
	switch "" {
	case couchbaseConnection:
		t.Fatal("'COUCHBASE_CONNECTION' not set")
	case couchbaseBucket:
		t.Fatal("'COUCHBASE_BUCKET' not set")
	case couchbaseScope:
		t.Fatal("'COUCHBASE_SCOPE' not set")
	case couchbaseUser:
		t.Fatal("'COUCHBASE_USER' not set")
	case couchbasePass:
		t.Fatal("'COUCHBASE_PASS' not set")
	}

	return map[string]any{
		"kind":                 couchbaseSourceKind,
		"connectionString":     couchbaseConnection,
		"bucket":               couchbaseBucket,
		"scope":                couchbaseScope,
		"username":             couchbaseUser,
		"password":             couchbasePass,
		"queryScanConsistency": 2,
	}
}

// initCouchbaseCluster connects to the Couchbase cluster with password
// authentication. The caller must Close the returned cluster.
func initCouchbaseCluster(connectionString, username, password string) (*gocb.Cluster, error) {
	opts := gocb.ClusterOptions{
		Authenticator: gocb.PasswordAuthenticator{
			Username: username,
			Password: password,
		},
	}

	cluster, err := gocb.Connect(connectionString, opts)
	if err != nil {
		return nil, fmt.Errorf("gocb.Connect: %w", err)
	}
	return cluster, nil
}

// TestCouchbaseToolEndpoints runs the shared toolbox integration suites
// against a live Couchbase cluster, using one uniquely named collection per suite.
func TestCouchbaseToolEndpoints(t *testing.T) {
	sourceConfig := getCouchbaseVars(t)
	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()

	var args []string

	cluster, err := initCouchbaseCluster(couchbaseConnection, couchbaseUser, couchbasePass)
	if err != nil {
		t.Fatalf("unable to create Couchbase connection: %s", err)
	}
	defer cluster.Close(nil)

	// Create collection names with UUID
	collectionNameParam := "param_" + strings.ReplaceAll(uuid.New().String(), "-", "")
	collectionNameAuth := "auth_" + strings.ReplaceAll(uuid.New().String(), "-", "")
	collectionNameTemplateParam := "template_param_" + strings.ReplaceAll(uuid.New().String(), "-", "")

	// Set up data for param tool
	paramToolStatement, idParamToolStmt, nameParamToolStmt, arrayToolStatement, paramTestParams := getCouchbaseParamToolInfo(collectionNameParam)
	teardownCollection1 := setupCouchbaseCollection(t, ctx, cluster, couchbaseBucket, couchbaseScope, collectionNameParam, paramTestParams)
	defer teardownCollection1(t)

	// Set up data for auth tool
	authToolStatement, authTestParams := getCouchbaseAuthToolInfo(collectionNameAuth)
	teardownCollection2 := setupCouchbaseCollection(t, ctx, cluster, couchbaseBucket, couchbaseScope, collectionNameAuth, authTestParams)
	defer teardownCollection2(t)

	// Set up collection for template param tool
	tmplSelectCombined, tmplSelectFilterCombined, tmplSelectAll, params3 := getCouchbaseTemplateParamToolInfo()
	teardownCollection3 := setupCouchbaseCollection(t, ctx, cluster, couchbaseBucket, couchbaseScope, collectionNameTemplateParam, params3)
	defer teardownCollection3(t)

	// Write config into a file and pass it to command
	toolsFile := tests.GetToolsConfig(sourceConfig, couchbaseToolKind, paramToolStatement, idParamToolStmt, nameParamToolStmt, arrayToolStatement, authToolStatement)
	toolsFile = tests.AddTemplateParamConfig(t, toolsFile, couchbaseToolKind, tmplSelectCombined, tmplSelectFilterCombined, tmplSelectAll)

	cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
	if err != nil {
		t.Fatalf("command initialization returned an error: %s", err)
	}
	defer cleanup()

	waitCtx, waitCancel := context.WithTimeout(ctx, 10*time.Second)
	defer waitCancel()
	out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
	if err != nil {
		t.Logf("toolbox command logs: \n%s", out)
		t.Fatalf("toolbox didn't start successfully: %s", err)
	}

	// Get configs for tests
	select1Want := "[{\"$1\":1}]"
	mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: parsing failure | {\"statement\":\"SELEC 1;\"`
	mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"$1\":1}"}]}}`
	tmplSelectId1Want := "[{\"age\":21,\"id\":1,\"name\":\"Alex\"}]"
	selectAllWant := "[{\"age\":21,\"id\":1,\"name\":\"Alex\"},{\"age\":100,\"id\":2,\"name\":\"Alice\"}]"

	// Run tests
	tests.RunToolGetTest(t)
	tests.RunToolInvokeTest(t, select1Want)
	tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want)
	tests.RunToolInvokeWithTemplateParameters(t, collectionNameTemplateParam,
		tests.WithTmplSelectId1Want(tmplSelectId1Want),
		tests.WithSelectAllWant(selectAllWant),
		tests.DisableDdlTest(),
		tests.DisableInsertTest(),
	)
}

// setupCouchbaseCollection creates a scope and collection
and inserts test data func setupCouchbaseCollection(t *testing.T, ctx context.Context, cluster *gocb.Cluster, bucketName, scopeName, collectionName string, params []map[string]any) func(t *testing.T) { // Get bucket reference bucket := cluster.Bucket(bucketName) // Wait for bucket to be ready err := bucket.WaitUntilReady(5*time.Second, nil) if err != nil { t.Fatalf("failed to connect to bucket: %v", err) } // Create scope if it doesn't exist bucketMgr := bucket.CollectionsV2() err = bucketMgr.CreateScope(scopeName, nil) if err != nil && !strings.Contains(err.Error(), "already exists") { t.Logf("failed to create scope (might already exist): %v", err) } // Create a collection if it doesn't exist err = bucketMgr.CreateCollection(scopeName, collectionName, nil, nil) if err != nil && !strings.Contains(err.Error(), "already exists") { t.Fatalf("failed to create collection: %v", err) } // Get a reference to the collection collection := bucket.Scope(scopeName).Collection(collectionName) // Create primary index if it doesn't exist // Create primary index with retry logic maxRetries := 5 retryDelay := 50 * time.Millisecond actualRetries := 0 var lastErr error for attempt := 0; attempt < maxRetries; attempt++ { err = collection.QueryIndexes().CreatePrimaryIndex( &gocb.CreatePrimaryQueryIndexOptions{ IgnoreIfExists: true, }) if err == nil { lastErr = err // clear previous error break } lastErr = err t.Logf("Attempt %d: failed to create primary index: %v, retrying in %v", attempt+1, err, retryDelay) time.Sleep(retryDelay) // Exponential backoff retryDelay *= 2 actualRetries += 1 } if lastErr != nil { t.Fatalf("failed to create primary index collection after %d attempts: %v", actualRetries, lastErr) } // Insert test documents for i, param := range params { _, err = collection.Upsert(fmt.Sprintf("%d", i+1), param, &gocb.UpsertOptions{ DurabilityLevel: gocb.DurabilityLevelMajority, }) if err != nil { t.Fatalf("failed to insert test data: %v", err) } } // Return a cleanup function 
return func(t *testing.T) { // Drop the collection err := bucketMgr.DropCollection(scopeName, collectionName, nil) if err != nil { t.Logf("failed to drop collection: %v", err) } } } // getCouchbaseParamToolInfo returns statements and params for my-tool couchbase-sql kind func getCouchbaseParamToolInfo(collectionName string) (string, string, string, string, []map[string]any) { // N1QL uses positional or named parameters with $ prefix toolStatement := fmt.Sprintf("SELECT TONUMBER(meta().id) as id, "+ "%s.* FROM %s WHERE meta().id = TOSTRING($id) OR name = $name order by meta().id", collectionName, collectionName) idToolStatement := fmt.Sprintf("SELECT TONUMBER(meta().id) as id, "+ "%s.* FROM %s WHERE meta().id = TOSTRING($id) order by meta().id", collectionName, collectionName) nameToolStatement := fmt.Sprintf("SELECT TONUMBER(meta().id) as id, "+ "%s.* FROM %s WHERE name = $name order by meta().id", collectionName, collectionName) arrayToolStatemnt := fmt.Sprintf("SELECT TONUMBER(meta().id) as id, "+ "%s.* FROM %s WHERE TONUMBER(meta().id) IN $idArray AND name IN $nameArray order by meta().id", collectionName, collectionName) params := []map[string]any{ {"name": "Alice"}, {"name": "Jane"}, {"name": "Sid"}, {"name": nil}, } return toolStatement, idToolStatement, nameToolStatement, arrayToolStatemnt, params } // getCouchbaseAuthToolInfo returns statements and param of my-auth-tool for couchbase-sql kind func getCouchbaseAuthToolInfo(collectionName string) (string, []map[string]any) { toolStatement := fmt.Sprintf("SELECT name FROM %s WHERE email = $email", collectionName) params := []map[string]any{ {"name": "Alice", "email": tests.ServiceAccountEmail}, {"name": "Jane", "email": "[email protected]"}, } return toolStatement, params } func getCouchbaseTemplateParamToolInfo() (string, string, string, []map[string]any) { tmplSelectCombined := "SELECT {{.tableName}}.* FROM {{.tableName}} WHERE id = $id" tmplSelectFilterCombined := "SELECT {{.tableName}}.* FROM 
{{.tableName}} WHERE {{.columnFilter}} = $name" tmplSelectAll := "SELECT {{.tableName}}.* FROM {{.tableName}}" params := []map[string]any{ {"name": "Alex", "id": 1, "age": 21}, {"name": "Alice", "id": 2, "age": 100}, } return tmplSelectCombined, tmplSelectFilterCombined, tmplSelectAll, params } ``` -------------------------------------------------------------------------------- /internal/tools/firestore/firestoreupdatedocument/firestoreupdatedocument.go: -------------------------------------------------------------------------------- ```go // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package firestoreupdatedocument

import (
	"context"
	"fmt"
	"strings"
	"time"

	firestoreapi "cloud.google.com/go/firestore"
	yaml "github.com/goccy/go-yaml"
	"github.com/googleapis/genai-toolbox/internal/sources"
	firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
	"github.com/googleapis/genai-toolbox/internal/tools"
	"github.com/googleapis/genai-toolbox/internal/tools/firestore/util"
)

const (
	kind                  string = "firestore-update-document"
	documentPathKey       string = "documentPath"
	documentDataKey       string = "documentData"
	updateMaskKey         string = "updateMask"
	returnDocumentDataKey string = "returnData"
)

func init() {
	if !tools.Register(kind, newConfig) {
		panic(fmt.Sprintf("tool kind %q already registered", kind))
	}
}

// newConfig decodes a firestore-update-document tool configuration named name.
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
	actual := Config{Name: name}
	if err := decoder.DecodeContext(ctx, &actual); err != nil {
		return nil, err
	}
	return actual, nil
}

type compatibleSource interface {
	FirestoreClient() *firestoreapi.Client
}

// validate compatible sources are still compatible
var _ compatibleSource = &firestoreds.Source{}

var compatibleSources = [...]string{firestoreds.SourceKind}

// Config is the YAML configuration for a firestore-update-document tool.
type Config struct {
	Name         string   `yaml:"name" validate:"required"`
	Kind         string   `yaml:"kind" validate:"required"`
	Source       string   `yaml:"source" validate:"required"`
	Description  string   `yaml:"description" validate:"required"`
	AuthRequired []string `yaml:"authRequired"`
}

// validate interface
var _ tools.ToolConfig = Config{}

// ToolConfigKind returns the registered kind of this tool.
func (cfg Config) ToolConfigKind() string {
	return kind
}

// Initialize resolves the configured Firestore source and builds the tool with
// its parameter set and manifests.
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
	// verify source exists
	rawS, ok := srcs[cfg.Source]
	if !ok {
		return nil, fmt.Errorf("no source named %q configured", cfg.Source)
	}

	// verify the source is compatible
	s, ok := rawS.(compatibleSource)
	if !ok {
		return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
	}

	// Create parameters
	documentPathParameter := tools.NewStringParameter(
		documentPathKey,
		"The relative path of the document which needs to be updated (e.g., 'users/userId' or 'users/userId/posts/postId'). Note: This is a relative path, NOT an absolute path like 'projects/{project_id}/databases/{database_id}/documents/...'",
	)

	documentDataParameter := tools.NewMapParameter(
		documentDataKey,
		`The document data in Firestore's native JSON format. Each field must be wrapped with a type indicator:
- Strings: {"stringValue": "text"}
- Integers: {"integerValue": "123"} or {"integerValue": 123}
- Doubles: {"doubleValue": 123.45}
- Booleans: {"booleanValue": true}
- Timestamps: {"timestampValue": "2025-01-07T10:00:00Z"}
- GeoPoints: {"geoPointValue": {"latitude": 34.05, "longitude": -118.24}}
- Arrays: {"arrayValue": {"values": [{"stringValue": "item1"}, {"integerValue": "2"}]}}
- Maps: {"mapValue": {"fields": {"key1": {"stringValue": "value1"}, "key2": {"booleanValue": true}}}}
- Null: {"nullValue": null}
- Bytes: {"bytesValue": "base64EncodedString"}
- References: {"referenceValue": "collection/document"}`,
		"", // Empty string for generic map that accepts any value type
	)

	updateMaskParameter := tools.NewArrayParameterWithRequired(
		updateMaskKey,
		"The selective fields to update. If not provided, all fields in documentData will be updated. When provided, only the specified fields will be updated. Fields referenced in the mask but not present in documentData will be deleted from the document",
		false, // not required
		tools.NewStringParameter("field", "Field path to update or delete. Use dot notation to access nested fields within maps (e.g., 'address.city' to update the city field within an address map, or 'user.profile.name' for deeply nested fields). To delete a field, include it in the mask but omit it from documentData. Note: You cannot update individual array elements; you must update the entire array field"),
	)

	returnDataParameter := tools.NewBooleanParameterWithDefault(
		returnDocumentDataKey,
		false,
		"If set to true the output will have the data of the updated document. This flag if set to false will help avoid overloading the context of the agent.",
	)

	parameters := tools.Parameters{
		documentPathParameter,
		documentDataParameter,
		updateMaskParameter,
		returnDataParameter,
	}

	mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, parameters)

	// finish tool setup
	t := Tool{
		Name:         cfg.Name,
		Kind:         kind,
		Parameters:   parameters,
		AuthRequired: cfg.AuthRequired,
		Client:       s.FirestoreClient(),
		manifest:     tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
		mcpManifest:  mcpManifest,
	}
	return t, nil
}

// validate interface
var _ tools.Tool = Tool{}

// Tool updates a single Firestore document, either by merging all supplied
// fields or by applying only the fields named in an update mask.
type Tool struct {
	Name         string           `yaml:"name"`
	Kind         string           `yaml:"kind"`
	AuthRequired []string         `yaml:"authRequired"`
	Parameters   tools.Parameters `yaml:"parameters"`

	Client      *firestoreapi.Client
	manifest    tools.Manifest
	mcpManifest tools.McpManifest
}

// Invoke validates the parameters and writes the update.
//
// With a non-empty updateMask, each masked path is set from documentData, or
// deleted when documentData lacks it (DocumentRef.Update). Without a mask, all
// of documentData is merged into the document (DocumentRef.Set with MergeAll).
// The response holds the document path and the UTC update time, plus the
// re-read document data when returnData is true.
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
	mapParams := params.AsMap()

	documentPath, ok := mapParams[documentPathKey].(string)
	if !ok || documentPath == "" {
		return nil, fmt.Errorf("invalid or missing '%s' parameter", documentPathKey)
	}
	if err := util.ValidateDocumentPath(documentPath); err != nil {
		return nil, fmt.Errorf("invalid document path: %w", err)
	}

	documentDataRaw, ok := mapParams[documentDataKey]
	if !ok {
		return nil, fmt.Errorf("invalid or missing '%s' parameter", documentDataKey)
	}

	updatePaths, err := parseUpdateMask(mapParams[updateMaskKey])
	if err != nil {
		return nil, err
	}

	// A missing or non-bool flag means "don't return data".
	returnData, _ := mapParams[returnDocumentDataKey].(bool)

	// Both update strategies need the converted data as a field map, so convert
	// and validate it once.
	documentData, err := util.JSONToFirestoreValue(documentDataRaw, t.Client)
	if err != nil {
		return nil, fmt.Errorf("failed to convert document data: %w", err)
	}
	dataMap, ok := documentData.(map[string]any)
	if !ok {
		return nil, fmt.Errorf("document data must be a map")
	}

	docRef := t.Client.Doc(documentPath)

	var writeResult *firestoreapi.WriteResult
	if len(updatePaths) > 0 {
		updates := make([]firestoreapi.Update, 0, len(updatePaths))
		for _, path := range updatePaths {
			value, exists := getFieldValue(dataMap, path)
			if !exists {
				// Masked but absent from documentData: delete the field.
				value = firestoreapi.Delete
			}
			updates = append(updates, firestoreapi.Update{Path: path, Value: value})
		}
		writeResult, err = docRef.Update(ctx, updates)
	} else {
		writeResult, err = docRef.Set(ctx, dataMap, firestoreapi.MergeAll)
	}
	if err != nil {
		return nil, fmt.Errorf("failed to update document: %w", err)
	}

	response := map[string]any{
		"documentPath": docRef.Path,
		// Normalize to UTC so the "Z" suffix is always truthful.
		"updateTime": writeResult.UpdateTime.UTC().Format(time.RFC3339Nano),
	}

	if returnData {
		// Re-read the document so the caller sees its state after the write.
		snapshot, err := docRef.Get(ctx)
		if err != nil {
			return nil, fmt.Errorf("failed to retrieve updated document: %w", err)
		}
		response["documentData"] = util.FirestoreValueToJSON(snapshot.Data())
	}

	return response, nil
}

// parseUpdateMask converts the optional updateMask parameter into field paths.
// A nil mask yields no paths, meaning "merge all of documentData". A []string
// is used as-is and a []any is converted element-wise. Any other type is
// ignored and also yields no paths.
func parseUpdateMask(raw any) ([]string, error) {
	switch mask := raw.(type) {
	case []string:
		return mask, nil
	case []any:
		typedSlice, err := tools.ConvertAnySliceToTyped(mask, "string")
		if err != nil {
			return nil, fmt.Errorf("failed to convert update mask: %w", err)
		}
		paths, ok := typedSlice.([]string)
		if !ok {
			return nil, fmt.Errorf("unexpected type conversion error for update mask")
		}
		return paths, nil
	default:
		return nil, nil
	}
}

// getFieldValue looks up a dot-separated path (e.g. "address.city") in nested
// maps and reports whether the final key is present. A missing or non-map
// intermediate value yields (nil, false).
func getFieldValue(data map[string]any, path string) (any, bool) {
	parts := strings.Split(path, ".") // never empty, even for ""
	current := data
	for _, part := range parts[:len(parts)-1] {
		next, ok := current[part].(map[string]any)
		if !ok {
			return nil, false
		}
		current = next
	}
	value, exists := current[parts[len(parts)-1]]
	return value, exists
}

// ParseParams parses and validates raw request data against the tool's parameters.
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
	return tools.ParseParams(t.Parameters, data, claims)
}

// Manifest returns the tool manifest.
func (t Tool) Manifest() tools.Manifest {
	return t.manifest
}

// McpManifest returns the MCP tool manifest.
func (t Tool) McpManifest() tools.McpManifest {
	return t.mcpManifest
}

// Authorized reports whether the verified auth services satisfy AuthRequired.
func (t Tool) Authorized(verifiedAuthServices []string) bool {
	return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}

// RequiresClientAuthorization reports false: the tool uses the source's client.
func (t Tool) RequiresClientAuthorization() bool {
	return false
}
```
--------------------------------------------------------------------------------
/internal/log/log_test.go:
--------------------------------------------------------------------------------
```go
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package log import ( "bytes" "context" "encoding/json" "log/slog" "strings" "testing" "github.com/google/go-cmp/cmp" ) func TestSeverityToLevel(t *testing.T) { tcs := []struct { name string in string want slog.Level }{ { name: "test debug", in: "Debug", want: slog.LevelDebug, }, { name: "test info", in: "Info", want: slog.LevelInfo, }, { name: "test warn", in: "Warn", want: slog.LevelWarn, }, { name: "test error", in: "Error", want: slog.LevelError, }, } for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { got, err := SeverityToLevel(tc.in) if err != nil { t.Fatalf("unexpected error: %s", err) } if got != tc.want { t.Fatalf("incorrect level to severity: got %v, want %v", got, tc.want) } }) } } func TestSeverityToLevelError(t *testing.T) { _, err := SeverityToLevel("fail") if err == nil { t.Fatalf("expected error on incorrect level") } } func TestLevelToSeverity(t *testing.T) { tcs := []struct { name string in string want string }{ { name: "test debug", in: slog.LevelDebug.String(), want: "DEBUG", }, { name: "test info", in: slog.LevelInfo.String(), want: "INFO", }, { name: "test warn", in: slog.LevelWarn.String(), want: "WARN", }, { name: "test error", in: slog.LevelError.String(), want: "ERROR", }, } for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { got, err := levelToSeverity(tc.in) if err != nil { t.Fatalf("unexpected error: %s", err) } if got != tc.want { t.Fatalf("incorrect level to severity: got %v, want %v", got, tc.want) } }) } } func TestLevelToSeverityError(t *testing.T) { _, err := levelToSeverity("fail") if err == nil { 
t.Fatalf("expected error on incorrect slog level") } } func runLogger(logger Logger, logMsg string) { ctx := context.Background() switch logMsg { case "info": logger.InfoContext(ctx, "log info") case "debug": logger.DebugContext(ctx, "log debug") case "warn": logger.WarnContext(ctx, "log warn") case "error": logger.ErrorContext(ctx, "log error") } } func TestStdLogger(t *testing.T) { tcs := []struct { name string logLevel string logMsg string wantOut string wantErr string }{ { name: "debug logger logging debug", logLevel: "debug", logMsg: "debug", wantOut: "DEBUG \"log debug\" \n", wantErr: "", }, { name: "info logger logging debug", logLevel: "info", logMsg: "debug", wantOut: "", wantErr: "", }, { name: "warn logger logging debug", logLevel: "warn", logMsg: "debug", wantOut: "", wantErr: "", }, { name: "error logger logging debug", logLevel: "error", logMsg: "debug", wantOut: "", wantErr: "", }, { name: "debug logger logging info", logLevel: "debug", logMsg: "info", wantOut: "INFO \"log info\" \n", wantErr: "", }, { name: "info logger logging info", logLevel: "info", logMsg: "info", wantOut: "INFO \"log info\" \n", wantErr: "", }, { name: "warn logger logging info", logLevel: "warn", logMsg: "info", wantOut: "", wantErr: "", }, { name: "error logger logging info", logLevel: "error", logMsg: "info", wantOut: "", wantErr: "", }, { name: "debug logger logging warn", logLevel: "debug", logMsg: "warn", wantOut: "", wantErr: "WARN \"log warn\" \n", }, { name: "info logger logging warn", logLevel: "info", logMsg: "warn", wantOut: "", wantErr: "WARN \"log warn\" \n", }, { name: "warn logger logging warn", logLevel: "warn", logMsg: "warn", wantOut: "", wantErr: "WARN \"log warn\" \n", }, { name: "error logger logging warn", logLevel: "error", logMsg: "warn", wantOut: "", wantErr: "", }, { name: "debug logger logging error", logLevel: "debug", logMsg: "error", wantOut: "", wantErr: "ERROR \"log error\" \n", }, { name: "info logger logging error", logLevel: "info", logMsg: 
"error", wantOut: "", wantErr: "ERROR \"log error\" \n", }, { name: "warn logger logging error", logLevel: "warn", logMsg: "error", wantOut: "", wantErr: "ERROR \"log error\" \n", }, { name: "error logger logging error", logLevel: "error", logMsg: "error", wantOut: "", wantErr: "ERROR \"log error\" \n", }, } for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { outW := new(bytes.Buffer) errW := new(bytes.Buffer) logger, err := NewStdLogger(outW, errW, tc.logLevel) if err != nil { t.Fatalf("unexpected error: %s", err) } runLogger(logger, tc.logMsg) outWString := outW.String() spaceIndexOut := strings.Index(outWString, " ") gotOut := outWString[spaceIndexOut+1:] errWString := errW.String() spaceIndexErr := strings.Index(errWString, " ") gotErr := errWString[spaceIndexErr+1:] if diff := cmp.Diff(gotOut, tc.wantOut); diff != "" { t.Fatalf("incorrect log: diff %v", diff) } if diff := cmp.Diff(gotErr, tc.wantErr); diff != "" { t.Fatalf("incorrect log: diff %v", diff) } }) } } func TestStructuredLoggerDebugLog(t *testing.T) { tcs := []struct { name string logLevel string logMsg string wantOut map[string]string wantErr map[string]string }{ { name: "debug logger logging debug", logLevel: "debug", logMsg: "debug", wantOut: map[string]string{ "severity": "DEBUG", "message": "log debug", }, wantErr: map[string]string{}, }, { name: "info logger logging debug", logLevel: "info", logMsg: "debug", wantOut: map[string]string{}, wantErr: map[string]string{}, }, { name: "warn logger logging debug", logLevel: "warn", logMsg: "debug", wantOut: map[string]string{}, wantErr: map[string]string{}, }, { name: "error logger logging debug", logLevel: "error", logMsg: "debug", wantOut: map[string]string{}, wantErr: map[string]string{}, }, { name: "debug logger logging info", logLevel: "debug", logMsg: "info", wantOut: map[string]string{ "severity": "INFO", "message": "log info", }, wantErr: map[string]string{}, }, { name: "info logger logging info", logLevel: "info", logMsg: "info", 
wantOut: map[string]string{ "severity": "INFO", "message": "log info", }, wantErr: map[string]string{}, }, { name: "warn logger logging info", logLevel: "warn", logMsg: "info", wantOut: map[string]string{}, wantErr: map[string]string{}, }, { name: "error logger logging info", logLevel: "error", logMsg: "info", wantOut: map[string]string{}, wantErr: map[string]string{}, }, { name: "debug logger logging warn", logLevel: "debug", logMsg: "warn", wantOut: map[string]string{}, wantErr: map[string]string{ "severity": "WARN", "message": "log warn", }, }, { name: "info logger logging warn", logLevel: "info", logMsg: "warn", wantOut: map[string]string{}, wantErr: map[string]string{ "severity": "WARN", "message": "log warn", }, }, { name: "warn logger logging warn", logLevel: "warn", logMsg: "warn", wantOut: map[string]string{}, wantErr: map[string]string{ "severity": "WARN", "message": "log warn", }, }, { name: "error logger logging warn", logLevel: "error", logMsg: "warn", wantOut: map[string]string{}, wantErr: map[string]string{}, }, { name: "debug logger logging error", logLevel: "debug", logMsg: "error", wantOut: map[string]string{}, wantErr: map[string]string{ "severity": "ERROR", "message": "log error", }, }, { name: "info logger logging error", logLevel: "info", logMsg: "error", wantOut: map[string]string{}, wantErr: map[string]string{ "severity": "ERROR", "message": "log error", }, }, { name: "warn logger logging error", logLevel: "warn", logMsg: "error", wantOut: map[string]string{}, wantErr: map[string]string{ "severity": "ERROR", "message": "log error", }, }, { name: "error logger logging error", logLevel: "error", logMsg: "error", wantOut: map[string]string{}, wantErr: map[string]string{ "severity": "ERROR", "message": "log error", }, }, } for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { outW := new(bytes.Buffer) errW := new(bytes.Buffer) logger, err := NewStructuredLogger(outW, errW, tc.logLevel) if err != nil { t.Fatalf("unexpected error: %s", err) 
} runLogger(logger, tc.logMsg) if len(tc.wantOut) != 0 { got := make(map[string]interface{}) if err := json.Unmarshal(outW.Bytes(), &got); err != nil { t.Fatalf("failed to parse writer") } if got["severity"] != tc.wantOut["severity"] { t.Fatalf("incorrect severity: got %v, want %v", got["severity"], tc.wantOut["severity"]) } } else { if outW.String() != "" { t.Fatalf("incorrect log. got %v, want %v", outW.String(), "") } } if len(tc.wantErr) != 0 { got := make(map[string]interface{}) if err := json.Unmarshal(errW.Bytes(), &got); err != nil { t.Fatalf("failed to parse writer") } if got["severity"] != tc.wantErr["severity"] { t.Fatalf("incorrect severity: got %v, want %v", got["severity"], tc.wantErr["severity"]) } } else { if errW.String() != "" { t.Fatalf("incorrect log. got %v, want %v", errW.String(), "") } } }) } } ``` -------------------------------------------------------------------------------- /internal/tools/alloydb/alloydbwaitforoperation/alloydbwaitforoperation.go: -------------------------------------------------------------------------------- ```go // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package alloydbwaitforoperation import ( "context" "encoding/json" "fmt" "net/http" "strings" "text/template" "time" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin" "github.com/googleapis/genai-toolbox/internal/tools" ) const kind string = "alloydb-wait-for-operation" var alloyDBConnectionMessageTemplate = `Your AlloyDB resource is ready. To connect, please configure your environment. The method depends on how you are running the toolbox: **If running locally via stdio:** Update the MCP server configuration with the following environment variables: ` + "```json" + ` { "mcpServers": { "alloydb": { "command": "./PATH/TO/toolbox", "args": ["--prebuilt","alloydb-postgres","--stdio"], "env": { "ALLOYDB_POSTGRES_PROJECT": "{{.Project}}", "ALLOYDB_POSTGRES_REGION": "{{.Region}}", "ALLOYDB_POSTGRES_CLUSTER": "{{.Cluster}}", {{if .Instance}} "ALLOYDB_POSTGRES_INSTANCE": "{{.Instance}}", {{end}} "ALLOYDB_POSTGRES_DATABASE": "postgres", "ALLOYDB_POSTGRES_USER": ""{{.User}}",", "ALLOYDB_POSTGRES_PASSWORD": ""{{.Password}}", } } } } ` + "```" + ` **If running remotely:** For remote deployments, you will need to set the following environment variables in your deployment configuration: ` + "```" + ` ALLOYDB_POSTGRES_PROJECT={{.Project}} ALLOYDB_POSTGRES_REGION={{.Region}} ALLOYDB_POSTGRES_CLUSTER={{.Cluster}} {{if .Instance}}ALLOYDB_POSTGRES_INSTANCE={{.Instance}} {{end}}ALLOYDB_POSTGRES_DATABASE=postgres ALLOYDB_POSTGRES_USER=<your-user> ALLOYDB_POSTGRES_PASSWORD=<your-password> ` + "```" + ` Please refer to the official documentation for guidance on deploying the toolbox: - Deploying the Toolbox: https://googleapis.github.io/genai-toolbox/how-to/deploy_toolbox/ - Deploying on GKE: https://googleapis.github.io/genai-toolbox/how-to/deploy_gke/ ` func init() { if !tools.Register(kind, newConfig) { panic(fmt.Sprintf("tool kind %q already registered", kind)) } } 
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { actual := Config{Name: name} if err := decoder.DecodeContext(ctx, &actual); err != nil { return nil, err } return actual, nil } // Config defines the configuration for the wait-for-operation tool. type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` Source string `yaml:"source" validate:"required"` Description string `yaml:"description"` AuthRequired []string `yaml:"authRequired"` // Polling configuration Delay string `yaml:"delay"` MaxDelay string `yaml:"maxDelay"` Multiplier float64 `yaml:"multiplier"` MaxRetries int `yaml:"maxRetries"` } // validate interface var _ tools.ToolConfig = Config{} // ToolConfigKind returns the kind of the tool. func (cfg Config) ToolConfigKind() string { return kind } // Initialize initializes the tool from the configuration. func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { rawS, ok := srcs[cfg.Source] if !ok { return nil, fmt.Errorf("no source named %q configured", cfg.Source) } s, ok := rawS.(*alloydbadmin.Source) if !ok { return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind) } allParameters := tools.Parameters{ tools.NewStringParameter("project", "The project ID"), tools.NewStringParameter("location", "The location ID"), tools.NewStringParameter("operation", "The operation ID"), } paramManifest := allParameters.Manifest() description := cfg.Description if description == "" { description = "This will poll on operations API until the operation is done. For checking operation status we need projectId, locationID and operationId. Once instance is created give follow up steps on how to use the variables to bring data plane MCP server up in local and remote setup." 
} mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters) var delay time.Duration if cfg.Delay == "" { delay = 3 * time.Second } else { var err error delay, err = time.ParseDuration(cfg.Delay) if err != nil { return nil, fmt.Errorf("invalid value for delay: %w", err) } } var maxDelay time.Duration if cfg.MaxDelay == "" { maxDelay = 4 * time.Minute } else { var err error maxDelay, err = time.ParseDuration(cfg.MaxDelay) if err != nil { return nil, fmt.Errorf("invalid value for maxDelay: %w", err) } } multiplier := cfg.Multiplier if multiplier == 0 { multiplier = 2.0 } maxRetries := cfg.MaxRetries if maxRetries == 0 { maxRetries = 10 } return Tool{ Name: cfg.Name, Kind: kind, AuthRequired: cfg.AuthRequired, Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, Delay: delay, MaxDelay: maxDelay, Multiplier: multiplier, MaxRetries: maxRetries, }, nil } // Tool represents the wait-for-operation tool. type Tool struct { Name string `yaml:"name"` Kind string `yaml:"kind"` Description string `yaml:"description"` AuthRequired []string `yaml:"authRequired"` Source *alloydbadmin.Source AllParams tools.Parameters `yaml:"allParams"` // Polling configuration Delay time.Duration MaxDelay time.Duration Multiplier float64 MaxRetries int Client *http.Client manifest tools.Manifest mcpManifest tools.McpManifest } // Invoke executes the tool's logic. 
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) { paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) if !ok { return nil, fmt.Errorf("missing 'project' parameter") } location, ok := paramsMap["location"].(string) if !ok { return nil, fmt.Errorf("missing 'location' parameter") } operation, ok := paramsMap["operation"].(string) if !ok { return nil, fmt.Errorf("missing 'operation' parameter") } service, err := t.Source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } ctx, cancel := context.WithTimeout(ctx, 30*time.Minute) defer cancel() name := fmt.Sprintf("projects/%s/locations/%s/operations/%s", project, location, operation) delay := t.Delay maxDelay := t.MaxDelay multiplier := t.Multiplier maxRetries := t.MaxRetries retries := 0 for retries < maxRetries { select { case <-ctx.Done(): return nil, fmt.Errorf("timed out waiting for operation: %w", ctx.Err()) default: } op, err := service.Projects.Locations.Operations.Get(name).Do() if err != nil { fmt.Printf("error getting operation: %s, retrying in %v\n", err, delay) } else { if op.Done { if op.Error != nil { var errorBytes []byte errorBytes, err = json.Marshal(op.Error) if err != nil { return nil, fmt.Errorf("operation finished with error but could not marshal error object: %w", err) } return nil, fmt.Errorf("operation finished with error: %s", string(errorBytes)) } var opBytes []byte opBytes, err = op.MarshalJSON() if err != nil { return nil, fmt.Errorf("could not marshal operation: %w", err) } if msg, ok := t.generateAlloyDBConnectionMessage(map[string]any{"response": op.Response}); ok { return msg, nil } return string(opBytes), nil } fmt.Printf("Operation not complete, retrying in %v\n", delay) } time.Sleep(delay) delay = time.Duration(float64(delay) * multiplier) if delay > maxDelay { delay = maxDelay } retries++ } return nil, fmt.Errorf("exceeded max retries waiting for operation") } func (t Tool) 
generateAlloyDBConnectionMessage(responseData map[string]any) (string, bool) { resourceName, ok := responseData["name"].(string) if !ok { return "", false } parts := strings.Split(resourceName, "/") var project, region, cluster, instance string // Expected format: projects/{project}/locations/{location}/clusters/{cluster} // or projects/{project}/locations/{location}/clusters/{cluster}/instances/{instance} if len(parts) < 6 || parts[0] != "projects" || parts[2] != "locations" || parts[4] != "clusters" { return "", false } project = parts[1] region = parts[3] cluster = parts[5] if len(parts) >= 8 && parts[6] == "instances" { instance = parts[7] } else { return "", false } tmpl, err := template.New("alloydb-connection").Parse(alloyDBConnectionMessageTemplate) if err != nil { // This should not happen with a static template return fmt.Sprintf("template parsing error: %v", err), false } data := struct { Project string Region string Cluster string Instance string }{ Project: project, Region: region, Cluster: cluster, Instance: instance, } var b strings.Builder if err := tmpl.Execute(&b, data); err != nil { return fmt.Sprintf("template execution error: %v", err), false } return b.String(), true } // ParseParams parses the parameters for the tool. func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) { return tools.ParseParams(t.AllParams, data, claims) } // Manifest returns the tool's manifest. func (t Tool) Manifest() tools.Manifest { return t.manifest } // McpManifest returns the tool's MCP manifest. func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } // Authorized checks if the tool is authorized. 
func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } func (t Tool) RequiresClientAuthorization() bool { return t.Source.UseClientAuthorization() } ``` -------------------------------------------------------------------------------- /docs/en/resources/tools/_index.md: -------------------------------------------------------------------------------- ```markdown --- title: "Tools" type: docs weight: 2 description: > Tools define actions an agent can take -- such as reading and writing to a source. --- A tool represents an action your agent can take, such as running a SQL statement. You can define Tools as a map in the `tools` section of your `tools.yaml` file. Typically, a tool will require a source to act on: ```yaml tools: search_flights_by_number: kind: postgres-sql source: my-pg-instance statement: | SELECT * FROM flights WHERE airline = $1 AND flight_number = $2 LIMIT 10 description: | Use this tool to get information for a specific flight. Takes an airline code and flight number and returns info on the flight. Do NOT use this tool with a flight id. Do NOT guess an airline code or flight number. An airline code is a code for an airline service consisting of a two-character airline designator and followed by a flight number, which is a 1 to 4 digit number. For example, if given CY 0123, the airline is "CY", and flight_number is "123". Another example for this is DL 1234, the airline is "DL", and flight_number is "1234". If the tool returns more than one option choose the date closest to today. Example: {{ "airline": "CY", "flight_number": "888", }} Example: {{ "airline": "DL", "flight_number": "1234", }} parameters: - name: airline type: string description: Airline unique 2 letter identifier - name: flight_number type: string description: 1 to 4 digit number ``` ## Specifying Parameters Parameters for each Tool will define what inputs the agent will need to provide to invoke them. 
Parameters should be passed as a list of Parameter objects: ```yaml parameters: - name: airline type: string description: Airline unique 2 letter identifier - name: flight_number type: string description: 1 to 4 digit number ``` ### Basic Parameters Basic parameter types include `string`, `integer`, `float`, and `boolean`. In most cases, the description will be provided to the LLM as context on specifying the parameter. ```yaml parameters: - name: airline type: string description: Airline unique 2 letter identifier ``` | **field** | **type** | **required** | **description** | |-------------|:---------------:|:------------:|-----------------------------------------------------------------------------| | name | string | true | Name of the parameter. | | type | string | true | Must be one of "string", "integer", "float", "boolean", "array" | | description | string | true | Natural language description of the parameter to describe it to the agent. | | default | parameter type | false | Default value of the parameter. If provided, `required` will be `false`. | | required | bool | false | Indicates whether the parameter is required. Defaults to `true`. | ### Array Parameters The `array` type is a list of items passed in as a single parameter. To use the `array` type, you must also specify what kind of items are in the list using the items field: ```yaml parameters: - name: preferred_airlines type: array description: A list of airlines, ordered by preference. items: name: name type: string description: Name of the airline. statement: | SELECT * FROM airlines WHERE preferred_airlines = ANY($1); ``` | **field** | **type** | **required** | **description** | |-------------|:----------------:|:------------:|-----------------------------------------------------------------------------| | name | string | true | Name of the parameter. | | type | string | true | Must be "array" | | description | string | true | Natural language description of the parameter to describe it to the agent.
| | default | parameter type | false | Default value of the parameter. If provided, `required` will be `false`. | | required | bool | false | Indicates whether the parameter is required. Defaults to `true`. | | items | parameter object | true | Specify a Parameter object for the type of the values in the array. | {{< notice note >}} Items in an array should not have a `default` or `required` value. If provided, it will be ignored. {{< /notice >}} ### Map Parameters The map type is a collection of key-value pairs. It can be configured in two ways: - Generic Map: By default, it accepts values of any primitive type (string, integer, float, boolean), allowing for mixed data. - Typed Map: By setting the valueType field, you can enforce that all values within the map must be of the same specified type. #### Generic Map (Mixed Value Types) This is the default behavior when valueType is omitted. It's useful for passing a flexible group of settings. ```yaml parameters: - name: execution_context type: map description: A flexible set of key-value pairs for the execution environment. ``` #### Typed Map Specify valueType to ensure all values in the map are of the same type. An error will be raised if any value does not match the specified type. ```yaml parameters: - name: user_scores type: map description: A map of user IDs to their scores. All scores must be integers. valueType: integer # This enforces the value type for all entries. ``` ### Authenticated Parameters Authenticated parameters are automatically populated with user information decoded from [ID tokens](../authServices/#specifying-id-tokens-from-clients) that are passed in request headers. Unlike other parameters, they do not take input values from request bodies. To use authenticated parameters, you must configure the tool to map the required [authServices](../authServices/) to specific claims within the user's ID token.
```yaml tools: search_flights_by_user_id: kind: postgres-sql source: my-pg-instance statement: | SELECT * FROM flights WHERE user_id = $1 parameters: - name: user_id type: string description: Auto-populated from Google login authServices: # Refer to one of the `authServices` defined - name: my-google-auth # `sub` is the OIDC claim field for user ID field: sub ``` | **field** | **type** | **required** | **description** | |-----------|:--------:|:------------:|-----------------------------------------------------------------------------------------| | name | string | true | Name of the [authServices](../authServices/) used to verify the OIDC auth token. | | field | string | true | Claim field decoded from the OIDC token used to auto-populate this parameter. | ### Template Parameters Template parameter types include `string`, `integer`, `float`, and `boolean`. In most cases, the description will be provided to the LLM as context on specifying the parameter. Template parameters will be inserted into the SQL statement before executing the prepared statement. They will be inserted without quotes, so to insert a string using template parameters, quotes must be explicitly added within the string. Template parameter arrays can also be used similarly to basic parameters, and array items must be strings. Once inserted into the SQL statement, the outer layer of quotes will be removed. Therefore, to insert strings into the SQL statement, a set of quotes must be explicitly added within the string. {{< notice warning >}} Because template parameters can directly replace identifiers, column names, and table names, they are prone to SQL injection. Basic parameters are preferred for performance and safety reasons. {{< /notice >}} ```yaml tools: select_columns_from_table: kind: postgres-sql source: my-pg-instance statement: | SELECT {{array .columnNames}} FROM {{.tableName}} description: | Use this tool to list all information from a specific table.
Example: {{ "tableName": "flights", "columnNames": ["id", "name"] }} templateParameters: - name: tableName type: string description: Table to select from - name: columnNames type: array description: The columns to select items: name: column type: string description: Name of a column to select ``` | **field** | **type** | **required** | **description** | |-------------|:----------------:|:-------------:|-------------------------------------------------------------------------------------| | name | string | true | Name of the template parameter. | | type | string | true | Must be one of "string", "integer", "float", "boolean", "array" | | description | string | true | Natural language description of the template parameter to describe it to the agent. | | items | parameter object |true (if array)| Specify a Parameter object for the type of the values in the array (string only). | ## Authorized Invocations You can require an authorization check for any Tool invocation request by specifying an `authRequired` field. Specify a list of [authServices](../authServices/) defined in your `tools.yaml` file. ```yaml tools: search_all_flight: kind: postgres-sql source: my-pg-instance statement: | SELECT * FROM flights # A list of `authServices` defined previously authRequired: - my-google-auth - other-auth-service ``` ## Kinds of tools ``` -------------------------------------------------------------------------------- /tests/trino/trino_integration_test.go: -------------------------------------------------------------------------------- ```go // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package trino import ( "context" "database/sql" "fmt" "os" "regexp" "strings" "testing" "time" "github.com/google/uuid" "github.com/googleapis/genai-toolbox/internal/testutils" "github.com/googleapis/genai-toolbox/tests" _ "github.com/trinodb/trino-go-client/trino" // Import Trino SQL driver ) var ( TrinoSourceKind = "trino" TrinoToolKind = "trino-sql" TrinoHost = os.Getenv("TRINO_HOST") TrinoPort = os.Getenv("TRINO_PORT") TrinoUser = os.Getenv("TRINO_USER") TrinoPass = os.Getenv("TRINO_PASS") TrinoCatalog = os.Getenv("TRINO_CATALOG") TrinoSchema = os.Getenv("TRINO_SCHEMA") ) func getTrinoVars(t *testing.T) map[string]any { switch "" { case TrinoHost: t.Fatal("'TRINO_HOST' not set") case TrinoPort: t.Fatal("'TRINO_PORT' not set") // TrinoUser is optional for anonymous access case TrinoCatalog: t.Fatal("'TRINO_CATALOG' not set") case TrinoSchema: t.Fatal("'TRINO_SCHEMA' not set") } return map[string]any{ "kind": TrinoSourceKind, "host": TrinoHost, "port": TrinoPort, "user": TrinoUser, "password": TrinoPass, "catalog": TrinoCatalog, "schema": TrinoSchema, } } // initTrinoConnectionPool creates a Trino connection pool (copied from trino.go) func initTrinoConnectionPool(host, port, user, pass, catalog, schema string) (*sql.DB, error) { dsn, err := buildTrinoDSN(host, port, user, pass, catalog, schema, "", "", false, false) if err != nil { return nil, fmt.Errorf("failed to build DSN: %w", err) } db, err := sql.Open("trino", dsn) if err != nil { return nil, fmt.Errorf("failed to open connection: %w", err) } // Configure connection pool db.SetMaxOpenConns(10) 
db.SetMaxIdleConns(5) db.SetConnMaxLifetime(time.Hour) return db, nil } // buildTrinoDSN builds a Trino DSN string (simplified version from trino.go) func buildTrinoDSN(host, port, user, password, catalog, schema, queryTimeout, accessToken string, kerberosEnabled, sslEnabled bool) (string, error) { scheme := "http" if sslEnabled { scheme = "https" } // Build base DSN without user info dsn := fmt.Sprintf("%s://%s:%s?catalog=%s&schema=%s", scheme, host, port, catalog, schema) // Add user authentication if provided if user != "" { if password != "" { dsn = fmt.Sprintf("%s://%s:%s@%s:%s?catalog=%s&schema=%s", scheme, user, password, host, port, catalog, schema) } else { dsn = fmt.Sprintf("%s://%s@%s:%s?catalog=%s&schema=%s", scheme, user, host, port, catalog, schema) } } if queryTimeout != "" { dsn += "&queryTimeout=" + queryTimeout } if accessToken != "" { dsn += "&accessToken=" + accessToken } if kerberosEnabled { dsn += "&KerberosEnabled=true" } return dsn, nil } // getTrinoParamToolInfo returns statements and param for my-tool trino-sql kind func getTrinoParamToolInfo(tableName string) (string, string, string, string, string, string, []any) { createStatement := fmt.Sprintf("CREATE TABLE %s (id BIGINT NOT NULL, name VARCHAR(255))", tableName) insertStatement := fmt.Sprintf("INSERT INTO %s (id, name) VALUES (1, ?), (2, ?), (3, ?), (4, ?)", tableName) toolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = ? OR name = ?", tableName) idParamStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = ?", tableName) nameParamStatement := fmt.Sprintf("SELECT * FROM %s WHERE name = ?", tableName) arrayToolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id IN (?, ?) 
AND name IN (?, ?)", tableName) // Trino doesn't use ANY() like MySQL/PostgreSQL params := []any{"Alice", "Jane", "Sid", nil} return createStatement, insertStatement, toolStatement, idParamStatement, nameParamStatement, arrayToolStatement, params } // getTrinoAuthToolInfo returns statements and param of my-auth-tool for trino-sql kind func getTrinoAuthToolInfo(tableName string) (string, string, string, []any) { createStatement := fmt.Sprintf("CREATE TABLE %s (id BIGINT NOT NULL, name VARCHAR(255), email VARCHAR(255))", tableName) insertStatement := fmt.Sprintf("INSERT INTO %s (id, name, email) VALUES (1, ?, ?), (2, ?, ?)", tableName) toolStatement := fmt.Sprintf("SELECT name FROM %s WHERE email = ?", tableName) params := []any{"Alice", tests.ServiceAccountEmail, "Jane", "[email protected]"} return createStatement, insertStatement, toolStatement, params } // getTrinoTmplToolStatement returns statements and param for template parameter test cases for trino-sql kind func getTrinoTmplToolStatement() (string, string) { tmplSelectCombined := "SELECT * FROM {{.tableName}} WHERE id = ?" tmplSelectFilterCombined := "SELECT * FROM {{.tableName}} WHERE {{.columnFilter}} = ?" return tmplSelectCombined, tmplSelectFilterCombined } // getTrinoWants return the expected wants for trino func getTrinoWants() (string, string, string, string) { select1Want := `[{"_col0":1}]` failInvocationWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: trino: query failed (200 OK): \"USER_ERROR: line 1:1: mismatched input 'SELEC'. 
Expecting: 'ALTER', 'ANALYZE', 'CALL', 'COMMENT', 'COMMIT', 'CREATE', 'DEALLOCATE', 'DELETE', 'DENY', 'DESC', 'DESCRIBE', 'DROP', 'EXECUTE', 'EXPLAIN', 'GRANT', 'INSERT', 'MERGE', 'PREPARE', 'REFRESH', 'RESET', 'REVOKE', 'ROLLBACK', 'SET', 'SHOW', 'START', 'TRUNCATE', 'UPDATE', 'USE', 'WITH', \u003cquery\u003e\""}],"isError":true}}` createTableStatement := `"CREATE TABLE t (id BIGINT NOT NULL, name VARCHAR(255))"` mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"_col0\":1}"}]}}` return select1Want, failInvocationWant, createTableStatement, mcpSelect1Want } // setupTrinoTable creates and inserts data into a table of tool // compatible with trino-sql tool func setupTrinoTable(t *testing.T, ctx context.Context, pool *sql.DB, createStatement, insertStatement, tableName string, params []any) func(*testing.T) { err := pool.PingContext(ctx) if err != nil { t.Fatalf("unable to connect to test database: %s", err) } // Create table _, err = pool.QueryContext(ctx, createStatement) if err != nil { t.Fatalf("unable to create test table %s: %s", tableName, err) } // Insert test data _, err = pool.QueryContext(ctx, insertStatement, params...) 
if err != nil { t.Fatalf("unable to insert test data: %s", err) } return func(t *testing.T) { // tear down test _, err = pool.ExecContext(ctx, fmt.Sprintf("DROP TABLE %s", tableName)) if err != nil { t.Errorf("Teardown failed: %s", err) } } } // addTrinoExecuteSqlConfig gets the tools config for `trino-execute-sql` func addTrinoExecuteSqlConfig(t *testing.T, config map[string]any) map[string]any { tools, ok := config["tools"].(map[string]any) if !ok { t.Fatalf("unable to get tools from config") } tools["my-exec-sql-tool"] = map[string]any{ "kind": "trino-execute-sql", "source": "my-instance", "description": "Tool to execute sql", } tools["my-auth-exec-sql-tool"] = map[string]any{ "kind": "trino-execute-sql", "source": "my-instance", "description": "Tool to execute sql", "authRequired": []string{ "my-google-auth", }, } config["tools"] = tools return config } func TestTrinoToolEndpoints(t *testing.T) { sourceConfig := getTrinoVars(t) ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() var args []string pool, err := initTrinoConnectionPool(TrinoHost, TrinoPort, TrinoUser, TrinoPass, TrinoCatalog, TrinoSchema) if err != nil { t.Fatalf("unable to create Trino connection pool: %s", err) } // create table name with UUID tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") // set up data for param tool createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := getTrinoParamToolInfo(tableNameParam) teardownTable1 := setupTrinoTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams) defer teardownTable1(t) // set up data for auth tool createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := 
getTrinoAuthToolInfo(tableNameAuth) teardownTable2 := setupTrinoTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams) defer teardownTable2(t) // Write config into a file and pass it to command toolsFile := tests.GetToolsConfig(sourceConfig, TrinoToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt) toolsFile = addTrinoExecuteSqlConfig(t, toolsFile) tmplSelectCombined, tmplSelectFilterCombined := getTrinoTmplToolStatement() toolsFile = tests.AddTemplateParamConfig(t, toolsFile, TrinoToolKind, tmplSelectCombined, tmplSelectFilterCombined, "") cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) if err != nil { t.Fatalf("command initialization returned an error: %s", err) } defer cleanup() waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) if err != nil { t.Logf("toolbox command logs: \n%s", out) t.Fatalf("toolbox didn't start successfully: %s", err) } // Get configs for tests select1Want, mcpMyFailToolWant, createTableStatement, mcpSelect1Want := getTrinoWants() // Run tests tests.RunToolGetTest(t) tests.RunToolInvokeTest(t, select1Want, tests.DisableArrayTest()) tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want) tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want) tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.WithInsert1Want(`[{"rows":1}]`)) } ``` -------------------------------------------------------------------------------- /internal/tools/postgres/postgreslisttables/postgreslisttables.go: -------------------------------------------------------------------------------- ```go // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package postgreslisttables import ( "context" "fmt" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/jackc/pgx/v5/pgxpool" ) const kind string = "postgres-list-tables" const listTablesStatement = ` WITH desired_relkinds AS ( SELECT ARRAY['r', 'p']::char[] AS kinds -- Always consider both 'TABLE' and 'PARTITIONED TABLE' ), table_info AS ( SELECT t.oid AS table_oid, ns.nspname AS schema_name, t.relname AS table_name, pg_get_userbyid(t.relowner) AS table_owner, obj_description(t.oid, 'pg_class') AS table_comment, t.relkind AS object_kind FROM pg_class t JOIN pg_namespace ns ON ns.oid = t.relnamespace CROSS JOIN desired_relkinds dk WHERE t.relkind = ANY(dk.kinds) -- Filter by selected table relkinds ('r', 'p') AND (NULLIF(TRIM($1), '') IS NULL OR t.relname = ANY(string_to_array($1,','))) -- $1 is object_names AND ns.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast') AND ns.nspname NOT LIKE 'pg_temp_%' AND ns.nspname NOT LIKE 'pg_toast_temp_%' ), columns_info AS ( SELECT att.attrelid AS table_oid, att.attname AS column_name, format_type(att.atttypid, att.atttypmod) AS data_type, att.attnum AS column_ordinal_position, att.attnotnull AS is_not_nullable, pg_get_expr(ad.adbin, ad.adrelid) AS column_default, col_description(att.attrelid, att.attnum) AS 
column_comment FROM pg_attribute att LEFT JOIN pg_attrdef ad ON att.attrelid = ad.adrelid AND att.attnum = ad.adnum JOIN table_info ti ON att.attrelid = ti.table_oid WHERE att.attnum > 0 AND NOT att.attisdropped ), constraints_info AS ( SELECT con.conrelid AS table_oid, con.conname AS constraint_name, pg_get_constraintdef(con.oid) AS constraint_definition, CASE con.contype WHEN 'p' THEN 'PRIMARY KEY' WHEN 'f' THEN 'FOREIGN KEY' WHEN 'u' THEN 'UNIQUE' WHEN 'c' THEN 'CHECK' ELSE con.contype::text END AS constraint_type, (SELECT array_agg(att.attname ORDER BY u.attposition) FROM unnest(con.conkey) WITH ORDINALITY AS u(attnum, attposition) JOIN pg_attribute att ON att.attrelid = con.conrelid AND att.attnum = u.attnum) AS constraint_columns, NULLIF(con.confrelid, 0)::regclass AS foreign_key_referenced_table, (SELECT array_agg(att.attname ORDER BY u.attposition) FROM unnest(con.confkey) WITH ORDINALITY AS u(attnum, attposition) JOIN pg_attribute att ON att.attrelid = con.confrelid AND att.attnum = u.attnum WHERE con.contype = 'f') AS foreign_key_referenced_columns FROM pg_constraint con JOIN table_info ti ON con.conrelid = ti.table_oid ), indexes_info AS ( SELECT idx.indrelid AS table_oid, ic.relname AS index_name, pg_get_indexdef(idx.indexrelid) AS index_definition, idx.indisunique AS is_unique, idx.indisprimary AS is_primary, am.amname AS index_method, (SELECT array_agg(att.attname ORDER BY u.ord) FROM unnest(idx.indkey::int[]) WITH ORDINALITY AS u(colidx, ord) LEFT JOIN pg_attribute att ON att.attrelid = idx.indrelid AND att.attnum = u.colidx WHERE u.colidx <> 0) AS index_columns FROM pg_index idx JOIN pg_class ic ON ic.oid = idx.indexrelid JOIN pg_am am ON am.oid = ic.relam JOIN table_info ti ON idx.indrelid = ti.table_oid ), triggers_info AS ( SELECT tg.tgrelid AS table_oid, tg.tgname AS trigger_name, pg_get_triggerdef(tg.oid) AS trigger_definition, tg.tgenabled AS trigger_enabled_state FROM pg_trigger tg JOIN table_info ti ON tg.tgrelid = ti.table_oid WHERE NOT 
tg.tgisinternal ) SELECT ti.schema_name, ti.table_name AS object_name, CASE WHEN $2 = 'simple' THEN -- IF format is 'simple', return basic JSON json_build_object('name', ti.table_name) ELSE json_build_object( 'schema_name', ti.schema_name, 'object_name', ti.table_name, 'object_type', CASE ti.object_kind WHEN 'r' THEN 'TABLE' WHEN 'p' THEN 'PARTITIONED TABLE' ELSE ti.object_kind::text -- Should not happen due to WHERE clause END, 'owner', ti.table_owner, 'comment', ti.table_comment, 'columns', COALESCE((SELECT json_agg(json_build_object('column_name',ci.column_name,'data_type',ci.data_type,'ordinal_position',ci.column_ordinal_position,'is_not_nullable',ci.is_not_nullable,'column_default',ci.column_default,'column_comment',ci.column_comment) ORDER BY ci.column_ordinal_position) FROM columns_info ci WHERE ci.table_oid = ti.table_oid), '[]'::json), 'constraints', COALESCE((SELECT json_agg(json_build_object('constraint_name',cons.constraint_name,'constraint_type',cons.constraint_type,'constraint_definition',cons.constraint_definition,'constraint_columns',cons.constraint_columns,'foreign_key_referenced_table',cons.foreign_key_referenced_table,'foreign_key_referenced_columns',cons.foreign_key_referenced_columns)) FROM constraints_info cons WHERE cons.table_oid = ti.table_oid), '[]'::json), 'indexes', COALESCE((SELECT json_agg(json_build_object('index_name',ii.index_name,'index_definition',ii.index_definition,'is_unique',ii.is_unique,'is_primary',ii.is_primary,'index_method',ii.index_method,'index_columns',ii.index_columns)) FROM indexes_info ii WHERE ii.table_oid = ti.table_oid), '[]'::json), 'triggers', COALESCE((SELECT json_agg(json_build_object('trigger_name',tri.trigger_name,'trigger_definition',tri.trigger_definition,'trigger_enabled_state',tri.trigger_enabled_state)) FROM triggers_info tri WHERE tri.table_oid = ti.table_oid), '[]'::json) ) END AS object_details FROM table_info ti ORDER BY ti.schema_name, ti.table_name; ` func init() { if !tools.Register(kind, 
newConfig) { panic(fmt.Sprintf("tool kind %q already registered", kind)) } } func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { actual := Config{Name: name} if err := decoder.DecodeContext(ctx, &actual); err != nil { return nil, err } return actual, nil } type compatibleSource interface { PostgresPool() *pgxpool.Pool } // validate compatible sources are still compatible var _ compatibleSource = &alloydbpg.Source{} var _ compatibleSource = &cloudsqlpg.Source{} var _ compatibleSource = &postgres.Source{} var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` Source string `yaml:"source" validate:"required"` Description string `yaml:"description" validate:"required"` AuthRequired []string `yaml:"authRequired"` } // validate interface var _ tools.ToolConfig = Config{} func (cfg Config) ToolConfigKind() string { return kind } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { // verify source exists rawS, ok := srcs[cfg.Source] if !ok { return nil, fmt.Errorf("no source named %q configured", cfg.Source) } // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) } allParameters := tools.Parameters{ tools.NewStringParameterWithDefault("table_names", "", "Optional: A comma-separated list of table names. 
If empty, details for all tables will be listed."), tools.NewStringParameterWithDefault("output_format", "detailed", "Optional: Use 'simple' for names only or 'detailed' for full info."), } paramManifest := allParameters.Manifest() mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters) t := Tool{ Name: cfg.Name, Kind: kind, AuthRequired: cfg.AuthRequired, AllParams: allParameters, Pool: s.PostgresPool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } return t, nil } // validate interface var _ tools.Tool = Tool{} type Tool struct { Name string `yaml:"name"` Kind string `yaml:"kind"` AuthRequired []string `yaml:"authRequired"` AllParams tools.Parameters `yaml:"allParams"` Pool *pgxpool.Pool manifest tools.Manifest mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) { paramsMap := params.AsMap() tableNames, ok := paramsMap["table_names"].(string) if !ok { return nil, fmt.Errorf("invalid 'table_names' parameter; expected a string") } outputFormat, _ := paramsMap["output_format"].(string) if outputFormat != "simple" && outputFormat != "detailed" { return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat) } results, err := t.Pool.Query(ctx, listTablesStatement, tableNames, outputFormat) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } defer results.Close() fields := results.FieldDescriptions() var out []map[string]any for results.Next() { values, err := results.Values() if err != nil { return nil, fmt.Errorf("unable to parse row: %w", err) } rowMap := make(map[string]any) for i, field := range fields { rowMap[string(field.Name)] = values[i] } out = append(out, rowMap) } if err := results.Err(); err != nil { return nil, fmt.Errorf("error reading query results: 
%w", err) } return out, nil } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) { return tools.ParseParams(t.AllParams, data, claims) } func (t Tool) Manifest() tools.Manifest { return t.manifest } func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } func (t Tool) RequiresClientAuthorization() bool { return false } ``` -------------------------------------------------------------------------------- /internal/server/api.go: -------------------------------------------------------------------------------- ```go // Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package server import ( "encoding/json" "errors" "fmt" "net/http" "strings" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/render" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/metric" ) // apiRouter creates a router that represents the routes under /api func apiRouter(s *Server) (chi.Router, error) { r := chi.NewRouter() r.Use(middleware.AllowContentType("application/json")) r.Use(middleware.StripSlashes) r.Use(render.SetContentType(render.ContentTypeJSON)) r.Get("/toolset", func(w http.ResponseWriter, r *http.Request) { toolsetHandler(s, w, r) }) r.Get("/toolset/{toolsetName}", func(w http.ResponseWriter, r *http.Request) { toolsetHandler(s, w, r) }) r.Route("/tool/{toolName}", func(r chi.Router) { r.Get("/", func(w http.ResponseWriter, r *http.Request) { toolGetHandler(s, w, r) }) r.Post("/invoke", func(w http.ResponseWriter, r *http.Request) { toolInvokeHandler(s, w, r) }) }) return r, nil } // toolsetHandler handles the request for information about a Toolset. 
func toolsetHandler(s *Server, w http.ResponseWriter, r *http.Request) { ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/toolset/get") r = r.WithContext(ctx) toolsetName := chi.URLParam(r, "toolsetName") s.logger.DebugContext(ctx, fmt.Sprintf("toolset name: %s", toolsetName)) span.SetAttributes(attribute.String("toolset_name", toolsetName)) var err error defer func() { if err != nil { span.SetStatus(codes.Error, err.Error()) } span.End() status := "success" if err != nil { status = "error" } s.instrumentation.ToolsetGet.Add( r.Context(), 1, metric.WithAttributes(attribute.String("toolbox.name", toolsetName)), metric.WithAttributes(attribute.String("toolbox.operation.status", status)), ) }() toolset, ok := s.ResourceMgr.GetToolset(toolsetName) if !ok { err = fmt.Errorf("toolset %q does not exist", toolsetName) s.logger.DebugContext(ctx, err.Error()) _ = render.Render(w, r, newErrResponse(err, http.StatusNotFound)) return } render.JSON(w, r, toolset.Manifest) } // toolGetHandler handles requests for a single Tool. 
func toolGetHandler(s *Server, w http.ResponseWriter, r *http.Request) { ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/tool/get") r = r.WithContext(ctx) toolName := chi.URLParam(r, "toolName") s.logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName)) span.SetAttributes(attribute.String("tool_name", toolName)) var err error defer func() { if err != nil { span.SetStatus(codes.Error, err.Error()) } span.End() status := "success" if err != nil { status = "error" } s.instrumentation.ToolGet.Add( r.Context(), 1, metric.WithAttributes(attribute.String("toolbox.name", toolName)), metric.WithAttributes(attribute.String("toolbox.operation.status", status)), ) }() tool, ok := s.ResourceMgr.GetTool(toolName) if !ok { err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) s.logger.DebugContext(ctx, err.Error()) _ = render.Render(w, r, newErrResponse(err, http.StatusNotFound)) return } // TODO: this can be optimized later with some caching m := tools.ToolsetManifest{ ServerVersion: s.version, ToolsManifest: map[string]tools.Manifest{ toolName: tool.Manifest(), }, } render.JSON(w, r, m) } // toolInvokeHandler handles the API request to invoke a specific Tool. 
//
// Processing order (each failure renders an errResponse and returns early):
//  1. Look up the tool named by the {toolName} URL parameter (404 if unknown).
//  2. If the tool requires client authorization, require a non-empty
//     Authorization header (401 otherwise).
//  3. Collect claims from every configured auth service that yields them.
//  4. Check tool.Authorized against the verified auth service names (401).
//  5. Decode the JSON body (400) and parse parameters (401 for
//     tools.ErrUnauthorized, 400 otherwise).
//  6. Invoke the tool and map its error to 401/403, 500 or 400 (see below).
//  7. JSON-marshal the result and return it as a string in resultResponse.
func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
	ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/tool/invoke")
	r = r.WithContext(ctx)
	// ctx additionally carries s.logger; it is the context handed to the auth
	// services and to tool.Invoke. r.Context() keeps only the span context.
	ctx = util.WithLogger(r.Context(), s.logger)

	toolName := chi.URLParam(r, "toolName")
	s.logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
	span.SetAttributes(attribute.String("tool_name", toolName))
	// err is read by the deferred func to set the span status and the
	// "toolbox.operation.status" metric attribute, so error paths below must
	// assign to this variable rather than to a shadowed copy.
	var err error
	defer func() {
		if err != nil {
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()

		status := "success"
		if err != nil {
			status = "error"
		}
		s.instrumentation.ToolInvoke.Add(
			r.Context(),
			1,
			metric.WithAttributes(attribute.String("toolbox.name", toolName)),
			metric.WithAttributes(attribute.String("toolbox.operation.status", status)),
		)
	}()

	tool, ok := s.ResourceMgr.GetTool(toolName)
	if !ok {
		err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
		s.logger.DebugContext(ctx, err.Error())
		_ = render.Render(w, r, newErrResponse(err, http.StatusNotFound))
		return
	}

	// Extract OAuth access token from the "Authorization" header (currently for
	// BigQuery end-user credentials usage only). The raw header value is passed
	// through unchanged.
	accessToken := tools.AccessToken(r.Header.Get("Authorization"))

	// Check if this specific tool requires the standard authorization header
	if tool.RequiresClientAuthorization() {
		if accessToken == "" {
			err = fmt.Errorf("tool requires client authorization but access token is missing from the request header")
			s.logger.DebugContext(ctx, err.Error())
			_ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized))
			return
		}
	}

	// Tool authentication
	// claimsFromAuth maps the name of the authservice to the claims retrieved from it.
	claimsFromAuth := make(map[string]map[string]any)
	for _, aS := range s.ResourceMgr.GetAuthServiceMap() {
		// This err intentionally shadows the outer err: a failure from one auth
		// service is logged and skipped rather than failing the request.
		claims, err := aS.GetClaimsFromHeader(ctx, r.Header)
		if err != nil {
			s.logger.DebugContext(ctx, err.Error())
			continue
		}
		if claims == nil {
			// authService not present in header
			continue
		}
		claimsFromAuth[aS.GetName()] = claims
	}

	// Tool authorization check: collect the names of the auth services that
	// produced claims (map iteration order, so the slice order is unspecified).
	verifiedAuthServices := make([]string, len(claimsFromAuth))
	i := 0
	for k := range claimsFromAuth {
		verifiedAuthServices[i] = k
		i++
	}

	// Check if any of the specified auth services is verified
	isAuthorized := tool.Authorized(verifiedAuthServices)
	if !isAuthorized {
		err = fmt.Errorf("tool invocation not authorized. Please make sure your specify correct auth headers")
		s.logger.DebugContext(ctx, err.Error())
		_ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized))
		return
	}
	s.logger.DebugContext(ctx, "tool invocation authorized")

	var data map[string]any
	if err = util.DecodeJSON(r.Body, &data); err != nil {
		// NOTE(review): redundant — errResponse.Render sets the status again
		// from HTTPStatusCode (the same 400).
		render.Status(r, http.StatusBadRequest)
		err = fmt.Errorf("request body was invalid JSON: %w", err)
		s.logger.DebugContext(ctx, err.Error())
		_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
		return
	}

	// At function scope, := only declares params; err is the variable above,
	// so the deferred status reflects failures from here on.
	params, err := tool.ParseParams(data, claimsFromAuth)
	if err != nil {
		// If auth error, return 401
		if errors.Is(err, tools.ErrUnauthorized) {
			s.logger.DebugContext(ctx, fmt.Sprintf("error parsing authenticated parameters from ID token: %s", err))
			_ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized))
			return
		}
		err = fmt.Errorf("provided parameters were invalid: %w", err)
		s.logger.DebugContext(ctx, err.Error())
		_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
		return
	}
	s.logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))

	res, err := tool.Invoke(ctx, params, accessToken)

	// Determine what error to return to the users.
	if err != nil {
		errStr := err.Error()
		var statusCode int

		// Upstream API auth error propagation: detected by substring match on
		// the error text, not by a typed error.
		switch {
		case strings.Contains(errStr, "Error 401"):
			statusCode = http.StatusUnauthorized
		case strings.Contains(errStr, "Error 403"):
			statusCode = http.StatusForbidden
		}

		if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
			if tool.RequiresClientAuthorization() {
				// Propagate the original 401/403 error.
				s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. Client credentials lack authorization to the source: %v", err))
				_ = render.Render(w, r, newErrResponse(err, statusCode))
				return
			}
			// ADC lacking permission or credentials configuration error.
			// Reported as a server-side 500 because the caller cannot fix it.
			internalErr := fmt.Errorf("unexpected auth error occured during Tool invocation: %w", err)
			s.logger.ErrorContext(ctx, internalErr.Error())
			_ = render.Render(w, r, newErrResponse(internalErr, http.StatusInternalServerError))
			return
		}
		// Any other invocation failure is reported as 400.
		err = fmt.Errorf("error while invoking tool: %w", err)
		s.logger.DebugContext(ctx, err.Error())
		_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
		return
	}

	resMarshal, err := json.Marshal(res)
	if err != nil {
		err = fmt.Errorf("unable to marshal result: %w", err)
		s.logger.DebugContext(ctx, err.Error())
		_ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError))
		return
	}

	// The result is sent as a JSON-encoded string inside the "result" field.
	_ = render.Render(w, r, &resultResponse{Result: string(resMarshal)})
}

var _ render.Renderer = &resultResponse{} // Renderer interface for managing response payloads.

// resultResponse is the response sent back when the tool was invoked successfully.
type resultResponse struct {
	Result string `json:"result"` // result of tool invocation
}

// Render sets the response status to 200 OK for a successful invocation.
func (rr resultResponse) Render(w http.ResponseWriter, r *http.Request) error {
	render.Status(r, http.StatusOK)
	return nil
}

var _ render.Renderer = &errResponse{} // Renderer interface for managing response payloads.
// newErrResponse is a helper function initializing an ErrResponse func newErrResponse(err error, code int) *errResponse { return &errResponse{ Err: err, HTTPStatusCode: code, StatusText: http.StatusText(code), ErrorText: err.Error(), } } // errResponse is the response sent back when an error has been encountered. type errResponse struct { Err error `json:"-"` // low-level runtime error HTTPStatusCode int `json:"-"` // http response status code StatusText string `json:"status"` // user-level status message ErrorText string `json:"error,omitempty"` // application-level error message, for debugging } func (e *errResponse) Render(w http.ResponseWriter, r *http.Request) error { render.Status(r, e.HTTPStatusCode) return nil } ``` -------------------------------------------------------------------------------- /docs/en/resources/sources/bigquery.md: -------------------------------------------------------------------------------- ```markdown --- title: "BigQuery" type: docs weight: 1 description: > BigQuery is Google Cloud's fully managed, petabyte-scale, and cost-effective analytics data warehouse that lets you run analytics over vast amounts of data in near real time. With BigQuery, there's no infrastructure to set up or manage, letting you focus on finding meaningful insights using GoogleSQL and taking advantage of flexible pricing models across on-demand and flat-rate options. --- # BigQuery Source [BigQuery][bigquery-docs] is Google Cloud's fully managed, petabyte-scale, and cost-effective analytics data warehouse that lets you run analytics over vast amounts of data in near real time. With BigQuery, there's no infrastructure to set up or manage, letting you focus on finding meaningful insights using GoogleSQL and taking advantage of flexible pricing models across on-demand and flat-rate options. If you are new to BigQuery, you can try to [load and query data with the bq tool][bigquery-quickstart-cli]. BigQuery uses [GoogleSQL][bigquery-googlesql] for querying data. 
GoogleSQL is an ANSI-compliant structured query language (SQL) that is also implemented for other Google Cloud services. Standard BigQuery query best practices apply when writing SQL against your data, such as selecting only the columns you need and filtering on partitioned or clustered columns to avoid unnecessary full table scans. [bigquery-docs]: https://cloud.google.com/bigquery/docs [bigquery-quickstart-cli]: https://cloud.google.com/bigquery/docs/quickstarts/quickstart-command-line [bigquery-googlesql]: https://cloud.google.com/bigquery/docs/reference/standard-sql/ ## Available Tools - [`bigquery-analyze-contribution`](../tools/bigquery/bigquery-analyze-contribution.md) Performs contribution analysis, also called key driver analysis in BigQuery. - [`bigquery-conversational-analytics`](../tools/bigquery/bigquery-conversational-analytics.md) Allows conversational interaction with a BigQuery source. - [`bigquery-execute-sql`](../tools/bigquery/bigquery-execute-sql.md) Execute structured queries using parameters. - [`bigquery-forecast`](../tools/bigquery/bigquery-forecast.md) Forecasts time series data in BigQuery. - [`bigquery-get-dataset-info`](../tools/bigquery/bigquery-get-dataset-info.md) Retrieve metadata for a specific dataset. - [`bigquery-get-table-info`](../tools/bigquery/bigquery-get-table-info.md) Retrieve metadata for a specific table. - [`bigquery-list-dataset-ids`](../tools/bigquery/bigquery-list-dataset-ids.md) List available dataset IDs. - [`bigquery-list-table-ids`](../tools/bigquery/bigquery-list-table-ids.md) List tables in a given dataset. - [`bigquery-sql`](../tools/bigquery/bigquery-sql.md) Run SQL queries directly against BigQuery datasets. - [`bigquery-search-catalog`](../tools/bigquery/bigquery-search_catalog.md) List all entries in Dataplex Catalog (e.g. tables, views, models) that match a given user query.
### Pre-built Configurations - [BigQuery using MCP](https://googleapis.github.io/genai-toolbox/how-to/connect-ide/bigquery_mcp/) Connect your IDE to BigQuery using Toolbox. ## Requirements ### IAM Permissions BigQuery uses [Identity and Access Management (IAM)][iam-overview] to control user and group access to BigQuery resources like projects, datasets, and tables. ### Authentication via Application Default Credentials (ADC) By **default**, Toolbox will use your [Application Default Credentials (ADC)][adc] to authorize and authenticate when interacting with [BigQuery][bigquery-docs]. When using this method, you need to ensure the IAM identity associated with your ADC (such as a service account) has the correct permissions for the queries you intend to run. Common roles include `roles/bigquery.user` (which includes permissions to run jobs) together with `roles/bigquery.dataViewer` (which grants read access to table data). Follow this [guide][set-adc] to set up your ADC. ### Authentication via User's OAuth Access Token If the `useClientOAuth` parameter is set to `true`, Toolbox will instead use the OAuth access token for authentication. This token is parsed from the `Authorization` header passed in with the tool invocation request. This method allows Toolbox to make queries to [BigQuery][bigquery-docs] on behalf of the client or the end-user. When using this on-behalf-of authentication, you must ensure that the identity used has been granted the correct IAM permissions. [iam-overview]: <https://cloud.google.com/bigquery/docs/access-control> [adc]: <https://cloud.google.com/docs/authentication#adc> [set-adc]: <https://cloud.google.com/docs/authentication/provide-credentials-adc> ## Example Initialize a BigQuery source that uses ADC: ```yaml sources: my-bigquery-source: kind: "bigquery" project: "my-project-id" # location: "US" # Optional: Specifies the location for query jobs. # writeMode: "allowed" # One of: allowed, blocked, protected. Defaults to "allowed".
# allowedDatasets: # Optional: Restricts tool access to a specific list of datasets. # - "my_dataset_1" # - "other_project.my_dataset_2" ``` Initialize a BigQuery source that uses the client's access token: ```yaml sources: my-bigquery-client-auth-source: kind: "bigquery" project: "my-project-id" useClientOAuth: true # location: "US" # Optional: Specifies the location for query jobs. # writeMode: "allowed" # One of: allowed, blocked, protected. Defaults to "allowed". # allowedDatasets: # Optional: Restricts tool access to a specific list of datasets. # - "my_dataset_1" # - "other_project.my_dataset_2" ``` ## Reference | **field** | **type** | **required** | **description** | |-----------------|:--------:|:------------:|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | kind | string | true | Must be "bigquery". | | project | string | true | Id of the Google Cloud project to use for billing and as the default project for BigQuery resources. | | location | string | false | Specifies the location (e.g., 'us', 'asia-northeast1') in which to run the query job. This location must match the location of any tables referenced in the query. Defaults to the table's location or 'US' if the location cannot be determined. [Learn More](https://cloud.google.com/bigquery/docs/locations) | | writeMode | string | false | Controls the write behavior for tools. `allowed` (default): All queries are permitted. `blocked`: Only `SELECT` statements are allowed for the `bigquery-execute-sql` tool. 
`protected`: Enables session-based execution where all tools associated with this source instance share the same [BigQuery session](https://cloud.google.com/bigquery/docs/sessions-intro). This allows for stateful operations using temporary tables (e.g., `CREATE TEMP TABLE`). For `bigquery-execute-sql`, `SELECT` statements can be used on all tables, but write operations are restricted to the session's temporary dataset. For tools like `bigquery-sql`, `bigquery-forecast`, and `bigquery-analyze-contribution`, the `writeMode` restrictions do not apply, but they will operate within the shared session. **Note:** The `protected` mode cannot be used with `useClientOAuth: true`. It is also not recommended for multi-user server environments, as all users would share the same session. A session is terminated automatically after 24 hours of inactivity or after 7 days, whichever comes first. A new session is created on the next request, and any temporary data from the previous session will be lost. | | allowedDatasets | []string | false | An optional list of dataset IDs that tools using this source are allowed to access. If provided, any tool operation attempting to access a dataset not in this list will be rejected. To enforce this, two types of operations are also disallowed: 1) Dataset-level operations (e.g., `CREATE SCHEMA`), and 2) operations where table access cannot be statically analyzed (e.g., `EXECUTE IMMEDIATE`, `CREATE PROCEDURE`). If a single dataset is provided, it will be treated as the default for prebuilt tools. | | useClientOAuth | bool | false | If true, forwards the client's OAuth access token from the "Authorization" header to downstream queries. **Note:** This cannot be used with `writeMode: protected`. 
| ``` -------------------------------------------------------------------------------- /internal/tools/neo4j/neo4jexecutecypher/classifier/classifier_test.go: -------------------------------------------------------------------------------- ```go // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package classifier import ( "reflect" "sort" "testing" ) // assertElementsMatch checks if two string slices have the same elements, ignoring order. // It serves as a replacement for testify's assert.ElementsMatch. func assertElementsMatch(t *testing.T, expected, actual []string, msg string) { // t.Helper() marks this function as a test helper. // When t.Errorf is called from this function, the line number of the calling code is reported, not the line number inside this helper. t.Helper() if len(expected) == 0 && len(actual) == 0 { return // Both are empty or nil, they match. } // Create copies to sort, leaving the original slices unmodified. expectedCopy := make([]string, len(expected)) actualCopy := make([]string, len(actual)) copy(expectedCopy, expected) copy(actualCopy, actual) sort.Strings(expectedCopy) sort.Strings(actualCopy) // reflect.DeepEqual provides a robust comparison for complex types, including sorted slices. 
if !reflect.DeepEqual(expectedCopy, actualCopy) { t.Errorf("%s: \nexpected: %v\n got: %v", msg, expected, actual) } } func TestQueryClassifier_Classify(t *testing.T) { classifier := NewQueryClassifier() tests := []struct { name string query string expectedType QueryType expectedWrite []string expectedRead []string minConfidence float64 }{ // Read queries { name: "simple MATCH query", query: "MATCH (n:Person) RETURN n", expectedType: ReadQuery, expectedRead: []string{"MATCH", "RETURN"}, expectedWrite: []string{}, minConfidence: 1.0, }, { name: "complex read query", query: "MATCH (p:Person)-[:KNOWS]->(f) WHERE p.age > 30 RETURN p.name, count(f) ORDER BY p.name SKIP 10 LIMIT 5", expectedType: ReadQuery, expectedRead: []string{"MATCH", "WHERE", "RETURN", "ORDER_BY", "SKIP", "LIMIT"}, expectedWrite: []string{}, minConfidence: 1.0, }, { name: "UNION query", query: "MATCH (n:Person) RETURN n.name UNION MATCH (m:Company) RETURN m.name", expectedType: ReadQuery, expectedRead: []string{"MATCH", "RETURN", "UNION", "MATCH", "RETURN"}, expectedWrite: []string{}, minConfidence: 1.0, }, // Write queries { name: "CREATE query", query: "CREATE (n:Person {name: 'John', age: 30})", expectedType: WriteQuery, expectedWrite: []string{"CREATE"}, expectedRead: []string{}, minConfidence: 1.0, }, { name: "MERGE query", query: "MERGE (n:Person {id: 123}) ON CREATE SET n.created = timestamp()", expectedType: WriteQuery, expectedWrite: []string{"MERGE", "CREATE", "SET"}, expectedRead: []string{}, minConfidence: 1.0, }, { name: "DETACH DELETE query", query: "MATCH (n:Person) DETACH DELETE n", expectedType: WriteQuery, expectedWrite: []string{"DETACH_DELETE"}, expectedRead: []string{"MATCH"}, minConfidence: 0.9, }, // Procedure calls { name: "read procedure", query: "CALL db.labels() YIELD label RETURN label", expectedType: ReadQuery, expectedRead: []string{"RETURN", "CALL db.labels"}, expectedWrite: []string{}, minConfidence: 1.0, }, { name: "unknown procedure conservative", query: "CALL 
custom.procedure.doSomething()", expectedType: WriteQuery, expectedWrite: []string{"CALL custom.procedure.dosomething"}, expectedRead: []string{}, minConfidence: 0.8, }, { name: "unknown read-like procedure", query: "CALL custom.procedure.getUsers()", expectedType: ReadQuery, expectedRead: []string{"CALL custom.procedure.getusers"}, expectedWrite: []string{}, minConfidence: 1.0, }, // Subqueries { name: "read subquery", query: "CALL { MATCH (n:Person) RETURN n } RETURN n", expectedType: ReadQuery, expectedRead: []string{"MATCH", "RETURN", "RETURN"}, expectedWrite: []string{}, minConfidence: 1.0, }, { name: "write subquery", query: "CALL { CREATE (n:Person) RETURN n } RETURN n", expectedType: WriteQuery, expectedWrite: []string{"CREATE", "WRITE_IN_SUBQUERY"}, expectedRead: []string{"RETURN", "RETURN"}, minConfidence: 0.9, }, // Multiline Queries { name: "multiline read query with comments", query: ` // Find all people and their friends MATCH (p:Person)-[:KNOWS]->(f:Friend) /* Where the person is older than 25 */ WHERE p.age > 25 RETURN p.name, f.name `, expectedType: ReadQuery, expectedWrite: []string{}, expectedRead: []string{"MATCH", "WHERE", "RETURN"}, minConfidence: 1.0, }, { name: "multiline write query", query: ` MATCH (p:Person {name: 'Alice'}) CREATE (c:Company {name: 'Neo4j'}) CREATE (p)-[:WORKS_FOR]->(c) `, expectedType: WriteQuery, expectedWrite: []string{"CREATE", "CREATE"}, expectedRead: []string{"MATCH"}, minConfidence: 0.9, }, // Complex Subqueries { name: "nested read subquery", query: ` CALL { MATCH (p:Person) RETURN p } CALL { MATCH (c:Company) RETURN c } RETURN p, c `, expectedType: ReadQuery, expectedWrite: []string{}, expectedRead: []string{"MATCH", "RETURN", "MATCH", "RETURN", "RETURN"}, minConfidence: 1.0, }, { name: "subquery with write and outer read", query: ` MATCH (u:User {id: 1}) CALL { WITH u CREATE (p:Post {content: 'New post'}) CREATE (u)-[:AUTHORED]->(p) RETURN p } RETURN u.name, p.content `, expectedType: WriteQuery, expectedWrite: 
[]string{"CREATE", "CREATE", "WRITE_IN_SUBQUERY"}, expectedRead: []string{"MATCH", "WITH", "RETURN", "RETURN"}, minConfidence: 0.9, }, { name: "subquery with read passing to outer write", query: ` CALL { MATCH (p:Product {id: 'abc'}) RETURN p } WITH p SET p.lastViewed = timestamp() `, expectedType: WriteQuery, expectedWrite: []string{"SET"}, expectedRead: []string{"MATCH", "RETURN", "WITH"}, minConfidence: 0.9, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := classifier.Classify(tt.query) if tt.expectedType != result.Type { t.Errorf("Query type mismatch: expected %v, got %v", tt.expectedType, result.Type) } if result.Confidence < tt.minConfidence { t.Errorf("Confidence too low: expected at least %f, got %f", tt.minConfidence, result.Confidence) } assertElementsMatch(t, tt.expectedWrite, result.WriteTokens, "Write tokens mismatch") assertElementsMatch(t, tt.expectedRead, result.ReadTokens, "Read tokens mismatch") }) } } func TestQueryClassifier_AbuseCases(t *testing.T) { classifier := NewQueryClassifier() tests := []struct { name string query string expectedType QueryType expectedWrite []string expectedRead []string }{ { name: "write keyword in a string literal", query: `MATCH (n) WHERE n.name = 'MERGE (m)' RETURN n`, expectedType: ReadQuery, expectedWrite: []string{}, expectedRead: []string{"MATCH", "WHERE", "RETURN"}, }, { name: "incomplete SET clause", query: `MATCH (n) SET`, expectedType: WriteQuery, expectedWrite: []string{"SET"}, expectedRead: []string{"MATCH"}, }, { name: "keyword as a node label", query: `MATCH (n:CREATE) RETURN n`, expectedType: ReadQuery, expectedWrite: []string{}, // 'CREATE' should be seen as an identifier, not a keyword expectedRead: []string{"MATCH", "RETURN"}, }, { name: "unbalanced parentheses", query: `MATCH (n:Person RETURN n`, expectedType: ReadQuery, expectedWrite: []string{}, expectedRead: []string{"MATCH", "RETURN"}, }, { name: "unclosed curly brace in subquery", query: `CALL { MATCH (n) CREATE 
(m)`, expectedType: WriteQuery, expectedWrite: []string{"CREATE", "WRITE_IN_SUBQUERY"}, expectedRead: []string{"MATCH"}, }, { name: "semicolon inside a query part", query: `MATCH (n;Person) RETURN n`, expectedType: ReadQuery, expectedWrite: []string{}, expectedRead: []string{"MATCH", "RETURN"}, }, { name: "jumbled keywords without proper syntax", query: `RETURN CREATE MATCH DELETE`, expectedType: WriteQuery, // The classifier's job is to find the tokens, not validate the syntax. // It should find both read and write tokens. expectedWrite: []string{"CREATE", "DELETE"}, expectedRead: []string{"RETURN", "MATCH"}, }, { name: "write in a nested subquery", query: ` CALL { MATCH (a) CALL { CREATE (b:Thing) } RETURN a } RETURN "done" `, expectedType: WriteQuery, expectedWrite: []string{"CREATE", "WRITE_IN_SUBQUERY"}, expectedRead: []string{"MATCH", "RETURN", "RETURN"}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // This defer-recover block ensures the test fails gracefully if the Classify function panics, // which was the goal of the original assert.NotPanics call. 
defer func() { if r := recover(); r != nil { t.Fatalf("The code panicked on test '%s': %v", tt.name, r) } }() result := classifier.Classify(tt.query) if tt.expectedType != result.Type { t.Errorf("Query type mismatch: expected %v, got %v", tt.expectedType, result.Type) } if tt.expectedWrite != nil { assertElementsMatch(t, tt.expectedWrite, result.WriteTokens, "Write tokens mismatch") } if tt.expectedRead != nil { assertElementsMatch(t, tt.expectedRead, result.ReadTokens, "Read tokens mismatch") } }) } } func TestNormalizeQuery(t *testing.T) { classifier := NewQueryClassifier() t.Run("single line comment", func(t *testing.T) { input := "MATCH (n) // comment\nRETURN n" expected := "MATCH (n) RETURN n" result := classifier.normalizeQuery(input) if expected != result { t.Errorf("normalizeQuery failed:\nexpected: %q\n got: %q", expected, result) } }) } ``` -------------------------------------------------------------------------------- /internal/tools/firestore/util/converter_test.go: -------------------------------------------------------------------------------- ```go // Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package util

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"strings"
	"testing"
	"time"

	"google.golang.org/genproto/googleapis/type/latlng"
)

// TestJSONToFirestoreValue_ComplexDocument converts a document written with
// Firestore typed-value wrappers and checks the Go value produced for every
// field. The wrappers covered are stringValue, timestampValue, geoPointValue,
// booleanValue, integerValue, doubleValue, mapValue, arrayValue, nullValue and
// bytesValue, including nested maps and a map stored inside an array.
func TestJSONToFirestoreValue_ComplexDocument(t *testing.T) {
	// Fixture that exercises every typed-value wrapper verified below.
	jsonData := `{
		"name": { "stringValue": "Acme Corporation" },
		"establishmentDate": { "timestampValue": "2000-01-15T10:30:00Z" },
		"location": {
			"geoPointValue": {
				"latitude": 34.052235,
				"longitude": -118.243683
			}
		},
		"active": { "booleanValue": true },
		"employeeCount": { "integerValue": "1500" },
		"annualRevenue": { "doubleValue": 1234567.89 },
		"website": { "stringValue": "https://www.acmecorp.com" },
		"contactInfo": {
			"mapValue": {
				"fields": {
					"email": { "stringValue": "[email protected]" },
					"phone": { "stringValue": "+1-555-123-4567" },
					"address": {
						"mapValue": {
							"fields": {
								"street": { "stringValue": "123 Business Blvd" },
								"city": { "stringValue": "Los Angeles" },
								"state": { "stringValue": "CA" },
								"zipCode": { "stringValue": "90012" }
							}
						}
					}
				}
			}
		},
		"products": {
			"arrayValue": {
				"values": [
					{ "stringValue": "Product A" },
					{ "stringValue": "Product B" },
					{
						"mapValue": {
							"fields": {
								"productName": { "stringValue": "Product C Deluxe" },
								"version": { "integerValue": "2" },
								"features": {
									"arrayValue": {
										"values": [
											{ "stringValue": "Feature X" },
											{ "stringValue": "Feature Y" }
										]
									}
								}
							}
						}
					}
				]
			}
		},
		"notes": { "nullValue": null },
		"lastUpdated": { "timestampValue": "2025-07-30T11:47:59.000Z" },
		"binaryData": { "bytesValue": "SGVsbG8gV29ybGQh" }
	}`

	// Parse JSON
	var data interface{}
	if err := json.Unmarshal([]byte(jsonData), &data); err != nil {
		t.Fatalf("Failed to unmarshal JSON: %v", err)
	}

	// Convert to Firestore format
	result, err := JSONToFirestoreValue(data, nil)
	if err != nil {
		t.Fatalf("Failed to convert JSON to Firestore value: %v", err)
	}

	// Verify the result is a map
	resultMap, ok := result.(map[string]interface{})
	if !ok {
		t.Fatalf("Result should be a map, got %T", result)
	}

	// Verify string values
	if resultMap["name"] != "Acme Corporation" {
		t.Errorf("Expected name 'Acme Corporation', got %v", resultMap["name"])
	}
	if resultMap["website"] != "https://www.acmecorp.com" {
		t.Errorf("Expected website 'https://www.acmecorp.com', got %v", resultMap["website"])
	}

	// Verify timestamps. A malformed expected literal is a bug in the test
	// itself, so fail loudly rather than silently comparing against zero time.
	establishmentDate, ok := resultMap["establishmentDate"].(time.Time)
	if !ok {
		t.Fatalf("establishmentDate should be time.Time, got %T", resultMap["establishmentDate"])
	}
	expectedDate, err := time.Parse(time.RFC3339, "2000-01-15T10:30:00Z")
	if err != nil {
		t.Fatalf("Failed to parse expected establishmentDate: %v", err)
	}
	if !establishmentDate.Equal(expectedDate) {
		t.Errorf("Expected date %v, got %v", expectedDate, establishmentDate)
	}

	// lastUpdated carries fractional seconds; it was previously never checked.
	lastUpdated, ok := resultMap["lastUpdated"].(time.Time)
	if !ok {
		t.Fatalf("lastUpdated should be time.Time, got %T", resultMap["lastUpdated"])
	}
	expectedLastUpdated, err := time.Parse(time.RFC3339Nano, "2025-07-30T11:47:59.000Z")
	if err != nil {
		t.Fatalf("Failed to parse expected lastUpdated: %v", err)
	}
	if !lastUpdated.Equal(expectedLastUpdated) {
		t.Errorf("Expected lastUpdated %v, got %v", expectedLastUpdated, lastUpdated)
	}

	// Verify geopoint
	location, ok := resultMap["location"].(*latlng.LatLng)
	if !ok {
		t.Fatalf("location should be *latlng.LatLng, got %T", resultMap["location"])
	}
	if location.Latitude != 34.052235 {
		t.Errorf("Expected latitude 34.052235, got %v", location.Latitude)
	}
	if location.Longitude != -118.243683 {
		t.Errorf("Expected longitude -118.243683, got %v", location.Longitude)
	}

	// Verify boolean
	if resultMap["active"] != true {
		t.Errorf("Expected active true, got %v", resultMap["active"])
	}

	// Verify integer (should be int64)
	employeeCount, ok := resultMap["employeeCount"].(int64)
	if !ok {
		t.Fatalf("employeeCount should be int64, got %T", resultMap["employeeCount"])
	}
	if employeeCount != int64(1500) {
		t.Errorf("Expected employeeCount 1500, got %v", employeeCount)
	}

	// Verify double
	annualRevenue, ok := resultMap["annualRevenue"].(float64)
	if !ok {
		t.Fatalf("annualRevenue should be float64, got %T", resultMap["annualRevenue"])
	}
	if annualRevenue != 1234567.89 {
		t.Errorf("Expected annualRevenue 1234567.89, got %v", annualRevenue)
	}

	// Verify nested map
	contactInfo, ok := resultMap["contactInfo"].(map[string]interface{})
	if !ok {
		t.Fatalf("contactInfo should be a map, got %T", resultMap["contactInfo"])
	}
	if contactInfo["email"] != "[email protected]" {
		t.Errorf("Expected email '[email protected]', got %v", contactInfo["email"])
	}
	if contactInfo["phone"] != "+1-555-123-4567" {
		t.Errorf("Expected phone '+1-555-123-4567', got %v", contactInfo["phone"])
	}

	// Verify nested nested map
	address, ok := contactInfo["address"].(map[string]interface{})
	if !ok {
		t.Fatalf("address should be a map, got %T", contactInfo["address"])
	}
	if address["street"] != "123 Business Blvd" {
		t.Errorf("Expected street '123 Business Blvd', got %v", address["street"])
	}
	if address["city"] != "Los Angeles" {
		t.Errorf("Expected city 'Los Angeles', got %v", address["city"])
	}
	if address["state"] != "CA" {
		t.Errorf("Expected state 'CA', got %v", address["state"])
	}
	if address["zipCode"] != "90012" {
		t.Errorf("Expected zipCode '90012', got %v", address["zipCode"])
	}

	// Verify array. The elements are indexed below, so a wrong length must
	// stop the test instead of panicking with an index out of range.
	products, ok := resultMap["products"].([]interface{})
	if !ok {
		t.Fatalf("products should be an array, got %T", resultMap["products"])
	}
	if len(products) != 3 {
		t.Fatalf("Expected 3 products, got %d", len(products))
	}
	if products[0] != "Product A" {
		t.Errorf("Expected products[0] 'Product A', got %v", products[0])
	}
	if products[1] != "Product B" {
		t.Errorf("Expected products[1] 'Product B', got %v", products[1])
	}

	// Verify complex item in array
	product3, ok := products[2].(map[string]interface{})
	if !ok {
		t.Fatalf("products[2] should be a map, got %T", products[2])
	}
	if product3["productName"] != "Product C Deluxe" {
		t.Errorf("Expected productName 'Product C Deluxe', got %v", product3["productName"])
	}
	version, ok := product3["version"].(int64)
	if !ok {
		t.Fatalf("version should be int64, got %T", product3["version"])
	}
	if version != int64(2) {
		t.Errorf("Expected version 2, got %v", version)
	}
	features, ok := product3["features"].([]interface{})
	if !ok {
		t.Fatalf("features should be an array, got %T", product3["features"])
	}
	if len(features) != 2 {
		t.Fatalf("Expected 2 features, got %d", len(features))
	}
	if features[0] != "Feature X" {
		t.Errorf("Expected features[0] 'Feature X', got %v", features[0])
	}
	if features[1] != "Feature Y" {
		t.Errorf("Expected features[1] 'Feature Y', got %v", features[1])
	}

	// Verify null value
	if resultMap["notes"] != nil {
		t.Errorf("Expected notes to be nil, got %v", resultMap["notes"])
	}

	// Verify bytes
	binaryData, ok := resultMap["binaryData"].([]byte)
	if !ok {
		t.Fatalf("binaryData should be []byte, got %T", resultMap["binaryData"])
	}
	expectedBytes, err := base64.StdEncoding.DecodeString("SGVsbG8gV29ybGQh")
	if err != nil {
		t.Fatalf("Failed to decode expected bytes: %v", err)
	}
	if !bytes.Equal(binaryData, expectedBytes) {
		t.Errorf("Expected bytes %v, got %v", expectedBytes, binaryData)
	}
}

// TestJSONToFirestoreValue_IntegerFromString checks that a top-level
// integerValue given as a JSON string is converted to int64.
func TestJSONToFirestoreValue_IntegerFromString(t *testing.T) {
	data := map[string]interface{}{
		"integerValue": "1500",
	}

	result, err := JSONToFirestoreValue(data, nil)
	if err != nil {
		t.Fatalf("Failed to convert: %v", err)
	}

	intVal, ok := result.(int64)
	if !ok {
		t.Fatalf("Result should be int64, got %T", result)
	}
	if intVal != int64(1500) {
		t.Errorf("Expected 1500, got %v", intVal)
	}
}

// TestFirestoreValueToJSON_RoundTrip checks the native-to-JSON direction only.
// FirestoreValueToJSON must return a map for a map input and must render a
// nested time.Time as a string. Despite the name, no conversion back through
// JSONToFirestoreValue is performed.
func TestFirestoreValueToJSON_RoundTrip(t *testing.T) {
	original := map[string]interface{}{
		"name":   "Test",
		"count":  int64(42),
		"price":  19.99,
		"active": true,
		"tags":   []interface{}{"tag1", "tag2"},
		"metadata": map[string]interface{}{
			"created": time.Now(),
		},
		"nullField": nil,
	}

	// Convert to JSON representation
	jsonRepresentation := FirestoreValueToJSON(original)

	// Verify types are simplified
	jsonMap, ok := jsonRepresentation.(map[string]interface{})
	if !ok {
		t.Fatalf("Expected map, got %T", jsonRepresentation)
	}

	// Time should be converted to string
	metadata, ok := jsonMap["metadata"].(map[string]interface{})
	if !ok {
		t.Fatalf("metadata should be a map, got %T", jsonMap["metadata"])
	}
	if _, ok := metadata["created"].(string); !ok {
		t.Errorf("created should be a string, got %T", metadata["created"])
	}
}

// TestJSONToFirestoreValue_InvalidFormats checks that malformed typed values
// are rejected. When errMsg is set, the error text must also contain it.
func TestJSONToFirestoreValue_InvalidFormats(t *testing.T) {
	tests := []struct {
		name    string
		input   interface{}
		wantErr bool
		errMsg  string
	}{
		{
			name: "invalid integer value",
			input: map[string]interface{}{
				"integerValue": "not-a-number",
			},
			wantErr: true,
			errMsg:  "invalid integer value",
		},
		{
			name: "invalid timestamp",
			input: map[string]interface{}{
				"timestampValue": "not-a-timestamp",
			},
			wantErr: true,
			errMsg:  "invalid timestamp format",
		},
		{
			name: "invalid geopoint - missing latitude",
			input: map[string]interface{}{
				"geoPointValue": map[string]interface{}{
					"longitude": -118.243683,
				},
			},
			wantErr: true,
			errMsg:  "invalid geopoint value format",
		},
		{
			name: "invalid array format",
			input: map[string]interface{}{
				"arrayValue": "not-an-array",
			},
			wantErr: true,
			errMsg:  "invalid array value format",
		},
		{
			name: "invalid map format",
			input: map[string]interface{}{
				"mapValue": "not-a-map",
			},
			wantErr: true,
			errMsg:  "invalid map value format",
		},
		{
			name: "invalid bytes - not base64",
			input: map[string]interface{}{
				"bytesValue": "!!!not-base64!!!",
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			_, err := JSONToFirestoreValue(tt.input, nil)
			if tt.wantErr {
				if err == nil {
					t.Errorf("Expected error but got none")
				} else if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) {
					t.Errorf("Expected error containing '%s', got '%v'", tt.errMsg, err)
				}
			} else {
				if err != nil {
					t.Errorf("Unexpected error: %v", err)
				}
			}
		})
	}
}
```