This is page 30 of 45. Use http://codebase.md/googleapis/genai-toolbox?lines=true&page={x} to view the full context. # Directory Structure ``` ├── .ci │ ├── continuous.release.cloudbuild.yaml │ ├── generate_release_table.sh │ ├── integration.cloudbuild.yaml │ ├── quickstart_test │ │ ├── go.integration.cloudbuild.yaml │ │ ├── js.integration.cloudbuild.yaml │ │ ├── py.integration.cloudbuild.yaml │ │ ├── run_go_tests.sh │ │ ├── run_js_tests.sh │ │ ├── run_py_tests.sh │ │ └── setup_hotels_sample.sql │ ├── test_with_coverage.sh │ └── versioned.release.cloudbuild.yaml ├── .github │ ├── auto-label.yaml │ ├── blunderbuss.yml │ ├── CODEOWNERS │ ├── header-checker-lint.yml │ ├── ISSUE_TEMPLATE │ │ ├── bug_report.yml │ │ ├── config.yml │ │ ├── feature_request.yml │ │ └── question.yml │ ├── label-sync.yml │ ├── labels.yaml │ ├── PULL_REQUEST_TEMPLATE.md │ ├── release-please.yml │ ├── renovate.json5 │ ├── sync-repo-settings.yaml │ └── workflows │ ├── cloud_build_failure_reporter.yml │ ├── deploy_dev_docs.yaml │ ├── deploy_previous_version_docs.yaml │ ├── deploy_versioned_docs.yaml │ ├── docs_deploy.yaml │ ├── docs_preview_clean.yaml │ ├── docs_preview_deploy.yaml │ ├── lint.yaml │ ├── schedule_reporter.yml │ ├── sync-labels.yaml │ └── tests.yaml ├── .gitignore ├── .gitmodules ├── .golangci.yaml ├── .hugo │ ├── archetypes │ │ └── default.md │ ├── assets │ │ ├── icons │ │ │ └── logo.svg │ │ └── scss │ │ ├── _styles_project.scss │ │ └── _variables_project.scss │ ├── go.mod │ ├── go.sum │ ├── hugo.toml │ ├── layouts │ │ ├── _default │ │ │ └── home.releases.releases │ │ ├── index.llms-full.txt │ │ ├── index.llms.txt │ │ ├── partials │ │ │ ├── hooks │ │ │ │ └── head-end.html │ │ │ ├── navbar-version-selector.html │ │ │ ├── page-meta-links.html │ │ │ └── td │ │ │ └── render-heading.html │ │ ├── robot.txt │ │ └── shortcodes │ │ ├── include.html │ │ ├── ipynb.html │ │ └── regionInclude.html │ ├── package-lock.json │ ├── package.json │ └── static │ ├── favicons │ │ ├── android-chrome-192x192.png │ │ ├── android-chrome-512x512.png │ │ ├── apple-touch-icon.png │ │ ├── favicon-16x16.png │ │ ├── favicon-32x32.png │ │ └── favicon.ico │ └── js │ └── w3.js ├── CHANGELOG.md ├── cmd │ ├── options_test.go │ ├── options.go │ ├── root_test.go │ ├── root.go │ └── version.txt ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── DEVELOPER.md ├── Dockerfile ├── docs │ └── en │ ├── _index.md │ ├── about │ │ ├── _index.md │ │ └── faq.md │ ├── concepts │ │ ├── _index.md │ │ └── telemetry │ │ ├── index.md │ │ ├── telemetry_flow.png │ │ └── telemetry_traces.png │ ├── getting-started │ │ ├── _index.md │ │ ├── colab_quickstart.ipynb │ │ ├── configure.md │ │ ├── introduction │ │ │ ├── _index.md │ │ │ └── architecture.png │ │ ├── local_quickstart_go.md │ │ ├── local_quickstart_js.md │ │ ├── local_quickstart.md │ │ ├── mcp_quickstart │ │ │ ├── _index.md │ │ │ ├── inspector_tools.png │ │ │ └── inspector.png │ │ └── quickstart │ │ ├── go │ │ │ ├── genAI │ │ │ │ ├── go.mod │ │ │ │ ├── go.sum │ │ │ │ └── quickstart.go │ │ │ ├── genkit │ │ │ │ ├── go.mod │ │ │ │ ├── go.sum │ │ │ │ └── quickstart.go │ │ │ ├── langchain │ │ │ │ ├── go.mod │ │ │ │ ├── go.sum │ │ │ │ └── quickstart.go │ │ │ ├── openAI │ │ │ │ ├── go.mod │ │ │ │ ├── go.sum │ │ │ │ └── quickstart.go │ │ │ └── quickstart_test.go │ │ ├── golden.txt │ │ ├── js │ │ │ ├── genAI │ │ │ │ ├── package-lock.json │ │ │ │ ├── package.json │ │ │ │ └── quickstart.js │ │ │ ├── genkit │ │ │ │ ├── package-lock.json │ │ │ │ ├── package.json │ │ │ │ └── quickstart.js │ │ │ ├── langchain │ │ │ │ ├── 
package-lock.json │ │ │ │ ├── package.json │ │ │ │ └── quickstart.js │ │ │ ├── llamaindex │ │ │ │ ├── package-lock.json │ │ │ │ ├── package.json │ │ │ │ └── quickstart.js │ │ │ └── quickstart.test.js │ │ ├── python │ │ │ ├── __init__.py │ │ │ ├── adk │ │ │ │ ├── quickstart.py │ │ │ │ └── requirements.txt │ │ │ ├── core │ │ │ │ ├── quickstart.py │ │ │ │ └── requirements.txt │ │ │ ├── langchain │ │ │ │ ├── quickstart.py │ │ │ │ └── requirements.txt │ │ │ ├── llamaindex │ │ │ │ ├── quickstart.py │ │ │ │ └── requirements.txt │ │ │ └── quickstart_test.py │ │ └── shared │ │ ├── cloud_setup.md │ │ ├── configure_toolbox.md │ │ └── database_setup.md │ ├── how-to │ │ ├── _index.md │ │ ├── connect_via_geminicli.md │ │ ├── connect_via_mcp.md │ │ ├── connect-ide │ │ │ ├── _index.md │ │ │ ├── alloydb_pg_admin_mcp.md │ │ │ ├── alloydb_pg_mcp.md │ │ │ ├── bigquery_mcp.md │ │ │ ├── cloud_sql_mssql_admin_mcp.md │ │ │ ├── cloud_sql_mssql_mcp.md │ │ │ ├── cloud_sql_mysql_admin_mcp.md │ │ │ ├── cloud_sql_mysql_mcp.md │ │ │ ├── cloud_sql_pg_admin_mcp.md │ │ │ ├── cloud_sql_pg_mcp.md │ │ │ ├── firestore_mcp.md │ │ │ ├── looker_mcp.md │ │ │ ├── mssql_mcp.md │ │ │ ├── mysql_mcp.md │ │ │ ├── neo4j_mcp.md │ │ │ ├── postgres_mcp.md │ │ │ ├── spanner_mcp.md │ │ │ └── sqlite_mcp.md │ │ ├── deploy_docker.md │ │ ├── deploy_gke.md │ │ ├── deploy_toolbox.md │ │ ├── export_telemetry.md │ │ └── toolbox-ui │ │ ├── edit-headers.gif │ │ ├── edit-headers.png │ │ ├── index.md │ │ ├── optional-param-checked.png │ │ ├── optional-param-unchecked.png │ │ ├── run-tool.gif │ │ ├── tools.png │ │ └── toolsets.png │ ├── reference │ │ ├── _index.md │ │ ├── cli.md │ │ └── prebuilt-tools.md │ ├── resources │ │ ├── _index.md │ │ ├── authServices │ │ │ ├── _index.md │ │ │ └── google.md │ │ ├── sources │ │ │ ├── _index.md │ │ │ ├── alloydb-admin.md │ │ │ ├── alloydb-pg.md │ │ │ ├── bigquery.md │ │ │ ├── bigtable.md │ │ │ ├── cassandra.md │ │ │ ├── clickhouse.md │ │ │ ├── cloud-monitoring.md │ │ │ ├── cloud-sql-admin.md │ │ │ ├── cloud-sql-mssql.md │ │ │ ├── cloud-sql-mysql.md │ │ │ ├── cloud-sql-pg.md │ │ │ ├── couchbase.md │ │ │ ├── dataplex.md │ │ │ ├── dgraph.md │ │ │ ├── firebird.md │ │ │ ├── firestore.md │ │ │ ├── http.md │ │ │ ├── looker.md │ │ │ ├── mongodb.md │ │ │ ├── mssql.md │ │ │ ├── mysql.md │ │ │ ├── neo4j.md │ │ │ ├── oceanbase.md │ │ │ ├── oracle.md │ │ │ ├── postgres.md │ │ │ ├── redis.md │ │ │ ├── spanner.md │ │ │ ├── sqlite.md │ │ │ ├── tidb.md │ │ │ ├── trino.md │ │ │ ├── valkey.md │ │ │ └── yugabytedb.md │ │ └── tools │ │ ├── _index.md │ │ ├── alloydb │ │ │ ├── _index.md │ │ │ ├── alloydb-create-cluster.md │ │ │ ├── alloydb-create-instance.md │ │ │ ├── alloydb-create-user.md │ │ │ ├── alloydb-get-cluster.md │ │ │ ├── alloydb-get-instance.md │ │ │ ├── alloydb-get-user.md │ │ │ ├── alloydb-list-clusters.md │ │ │ ├── alloydb-list-instances.md │ │ │ ├── alloydb-list-users.md │ │ │ └── alloydb-wait-for-operation.md │ │ ├── alloydbainl │ │ │ ├── _index.md │ │ │ └── alloydb-ai-nl.md │ │ ├── bigquery │ │ │ ├── _index.md │ │ │ ├── bigquery-analyze-contribution.md │ │ │ ├── bigquery-conversational-analytics.md │ │ │ ├── bigquery-execute-sql.md │ │ │ ├── bigquery-forecast.md │ │ │ ├── bigquery-get-dataset-info.md │ │ │ ├── bigquery-get-table-info.md │ │ │ ├── bigquery-list-dataset-ids.md │ │ │ ├── bigquery-list-table-ids.md │ │ │ ├── bigquery-search-catalog.md │ │ │ └── bigquery-sql.md │ │ ├── bigtable │ │ │ ├── _index.md │ │ │ └── bigtable-sql.md │ │ ├── cassandra │ │ │ ├── _index.md │ │ │ └── cassandra-cql.md │ │ ├── clickhouse │ │ 
│ ├── _index.md │ │ │ ├── clickhouse-execute-sql.md │ │ │ ├── clickhouse-list-databases.md │ │ │ ├── clickhouse-list-tables.md │ │ │ └── clickhouse-sql.md │ │ ├── cloudmonitoring │ │ │ ├── _index.md │ │ │ └── cloud-monitoring-query-prometheus.md │ │ ├── cloudsql │ │ │ ├── _index.md │ │ │ ├── cloudsqlcreatedatabase.md │ │ │ ├── cloudsqlcreateusers.md │ │ │ ├── cloudsqlgetinstances.md │ │ │ ├── cloudsqllistdatabases.md │ │ │ ├── cloudsqllistinstances.md │ │ │ ├── cloudsqlmssqlcreateinstance.md │ │ │ ├── cloudsqlmysqlcreateinstance.md │ │ │ ├── cloudsqlpgcreateinstances.md │ │ │ └── cloudsqlwaitforoperation.md │ │ ├── couchbase │ │ │ ├── _index.md │ │ │ └── couchbase-sql.md │ │ ├── dataform │ │ │ ├── _index.md │ │ │ └── dataform-compile-local.md │ │ ├── dataplex │ │ │ ├── _index.md │ │ │ ├── dataplex-lookup-entry.md │ │ │ ├── dataplex-search-aspect-types.md │ │ │ └── dataplex-search-entries.md │ │ ├── dgraph │ │ │ ├── _index.md │ │ │ └── dgraph-dql.md │ │ ├── firebird │ │ │ ├── _index.md │ │ │ ├── firebird-execute-sql.md │ │ │ └── firebird-sql.md │ │ ├── firestore │ │ │ ├── _index.md │ │ │ ├── firestore-add-documents.md │ │ │ ├── firestore-delete-documents.md │ │ │ ├── firestore-get-documents.md │ │ │ ├── firestore-get-rules.md │ │ │ ├── firestore-list-collections.md │ │ │ ├── firestore-query-collection.md │ │ │ ├── firestore-query.md │ │ │ ├── firestore-update-document.md │ │ │ └── firestore-validate-rules.md │ │ ├── http │ │ │ ├── _index.md │ │ │ └── http.md │ │ ├── looker │ │ │ ├── _index.md │ │ │ ├── looker-add-dashboard-element.md │ │ │ ├── looker-conversational-analytics.md │ │ │ ├── looker-get-dashboards.md │ │ │ ├── looker-get-dimensions.md │ │ │ ├── looker-get-explores.md │ │ │ ├── looker-get-filters.md │ │ │ ├── looker-get-looks.md │ │ │ ├── looker-get-measures.md │ │ │ ├── looker-get-models.md │ │ │ ├── looker-get-parameters.md │ │ │ ├── looker-health-analyze.md │ │ │ ├── looker-health-pulse.md │ │ │ ├── looker-health-vacuum.md │ │ │ ├── looker-make-dashboard.md │ │ │ ├── looker-make-look.md │ │ │ ├── looker-query-sql.md │ │ │ ├── looker-query-url.md │ │ │ ├── looker-query.md │ │ │ └── looker-run-look.md │ │ ├── mongodb │ │ │ ├── _index.md │ │ │ ├── mongodb-aggregate.md │ │ │ ├── mongodb-delete-many.md │ │ │ ├── mongodb-delete-one.md │ │ │ ├── mongodb-find-one.md │ │ │ ├── mongodb-find.md │ │ │ ├── mongodb-insert-many.md │ │ │ ├── mongodb-insert-one.md │ │ │ ├── mongodb-update-many.md │ │ │ └── mongodb-update-one.md │ │ ├── mssql │ │ │ ├── _index.md │ │ │ ├── mssql-execute-sql.md │ │ │ ├── mssql-list-tables.md │ │ │ └── mssql-sql.md │ │ ├── mysql │ │ │ ├── _index.md │ │ │ ├── mysql-execute-sql.md │ │ │ ├── mysql-list-active-queries.md │ │ │ ├── mysql-list-table-fragmentation.md │ │ │ ├── mysql-list-tables-missing-unique-indexes.md │ │ │ ├── mysql-list-tables.md │ │ │ └── mysql-sql.md │ │ ├── neo4j │ │ │ ├── _index.md │ │ │ ├── neo4j-cypher.md │ │ │ ├── neo4j-execute-cypher.md │ │ │ └── neo4j-schema.md │ │ ├── oceanbase │ │ │ ├── _index.md │ │ │ ├── oceanbase-execute-sql.md │ │ │ └── oceanbase-sql.md │ │ ├── oracle │ │ │ ├── _index.md │ │ │ ├── oracle-execute-sql.md │ │ │ └── oracle-sql.md │ │ ├── postgres │ │ │ ├── _index.md │ │ │ ├── postgres-execute-sql.md │ │ │ ├── postgres-list-active-queries.md │ │ │ ├── postgres-list-available-extensions.md │ │ │ ├── postgres-list-installed-extensions.md │ │ │ ├── postgres-list-tables.md │ │ │ └── postgres-sql.md │ │ ├── redis │ │ │ ├── _index.md │ │ │ └── redis.md │ │ ├── spanner │ │ │ ├── _index.md │ │ │ ├── spanner-execute-sql.md │ │ │ ├── 
spanner-list-tables.md │ │ │ └── spanner-sql.md │ │ ├── sqlite │ │ │ ├── _index.md │ │ │ ├── sqlite-execute-sql.md │ │ │ └── sqlite-sql.md │ │ ├── tidb │ │ │ ├── _index.md │ │ │ ├── tidb-execute-sql.md │ │ │ └── tidb-sql.md │ │ ├── trino │ │ │ ├── _index.md │ │ │ ├── trino-execute-sql.md │ │ │ └── trino-sql.md │ │ ├── utility │ │ │ ├── _index.md │ │ │ └── wait.md │ │ ├── valkey │ │ │ ├── _index.md │ │ │ └── valkey.md │ │ └── yuagbytedb │ │ ├── _index.md │ │ └── yugabytedb-sql.md │ ├── samples │ │ ├── _index.md │ │ ├── alloydb │ │ │ ├── _index.md │ │ │ ├── ai-nl │ │ │ │ ├── alloydb_ai_nl.ipynb │ │ │ │ └── index.md │ │ │ └── mcp_quickstart.md │ │ ├── bigquery │ │ │ ├── _index.md │ │ │ ├── colab_quickstart_bigquery.ipynb │ │ │ ├── local_quickstart.md │ │ │ └── mcp_quickstart │ │ │ ├── _index.md │ │ │ ├── inspector_tools.png │ │ │ └── inspector.png │ │ └── looker │ │ ├── _index.md │ │ ├── looker_gemini_oauth │ │ │ ├── _index.md │ │ │ ├── authenticated.png │ │ │ ├── authorize.png │ │ │ └── registration.png │ │ ├── looker_gemini.md │ │ └── looker_mcp_inspector │ │ ├── _index.md │ │ ├── inspector_tools.png │ │ └── inspector.png │ └── sdks │ ├── _index.md │ ├── go-sdk.md │ ├── js-sdk.md │ └── python-sdk.md ├── go.mod ├── go.sum ├── internal │ ├── auth │ │ ├── auth.go │ │ └── google │ │ └── google.go │ ├── log │ │ ├── handler.go │ │ ├── log_test.go │ │ ├── log.go │ │ └── logger.go │ ├── prebuiltconfigs │ │ ├── prebuiltconfigs_test.go │ │ ├── prebuiltconfigs.go │ │ └── tools │ │ ├── alloydb-postgres-admin.yaml │ │ ├── alloydb-postgres-observability.yaml │ │ ├── alloydb-postgres.yaml │ │ ├── bigquery.yaml │ │ ├── clickhouse.yaml │ │ ├── cloud-sql-mssql-admin.yaml │ │ ├── cloud-sql-mssql-observability.yaml │ │ ├── cloud-sql-mssql.yaml │ │ ├── cloud-sql-mysql-admin.yaml │ │ ├── cloud-sql-mysql-observability.yaml │ │ ├── cloud-sql-mysql.yaml │ │ ├── cloud-sql-postgres-admin.yaml │ │ ├── cloud-sql-postgres-observability.yaml │ │ ├── cloud-sql-postgres.yaml │ │ ├── dataplex.yaml │ │ ├── firestore.yaml │ │ ├── looker-conversational-analytics.yaml │ │ ├── looker.yaml │ │ ├── mssql.yaml │ │ ├── mysql.yaml │ │ ├── neo4j.yaml │ │ ├── oceanbase.yaml │ │ ├── postgres.yaml │ │ ├── spanner-postgres.yaml │ │ ├── spanner.yaml │ │ └── sqlite.yaml │ ├── server │ │ ├── api_test.go │ │ ├── api.go │ │ ├── common_test.go │ │ ├── config.go │ │ ├── mcp │ │ │ ├── jsonrpc │ │ │ │ ├── jsonrpc_test.go │ │ │ │ └── jsonrpc.go │ │ │ ├── mcp.go │ │ │ ├── util │ │ │ │ └── lifecycle.go │ │ │ ├── v20241105 │ │ │ │ ├── method.go │ │ │ │ └── types.go │ │ │ ├── v20250326 │ │ │ │ ├── method.go │ │ │ │ └── types.go │ │ │ └── v20250618 │ │ │ ├── method.go │ │ │ └── types.go │ │ ├── mcp_test.go │ │ ├── mcp.go │ │ ├── server_test.go │ │ ├── server.go │ │ ├── static │ │ │ ├── assets │ │ │ │ └── mcptoolboxlogo.png │ │ │ ├── css │ │ │ │ └── style.css │ │ │ ├── index.html │ │ │ ├── js │ │ │ │ ├── auth.js │ │ │ │ ├── loadTools.js │ │ │ │ ├── mainContent.js │ │ │ │ ├── navbar.js │ │ │ │ ├── runTool.js │ │ │ │ ├── toolDisplay.js │ │ │ │ ├── tools.js │ │ │ │ └── toolsets.js │ │ │ ├── tools.html │ │ │ └── toolsets.html │ │ ├── web_test.go │ │ └── web.go │ ├── sources │ │ ├── alloydbadmin │ │ │ ├── alloydbadmin_test.go │ │ │ └── alloydbadmin.go │ │ ├── alloydbpg │ │ │ ├── alloydb_pg_test.go │ │ │ └── alloydb_pg.go │ │ ├── bigquery │ │ │ ├── bigquery_test.go │ │ │ └── bigquery.go │ │ ├── bigtable │ │ │ ├── bigtable_test.go │ │ │ └── bigtable.go │ │ ├── cassandra │ │ │ ├── cassandra_test.go │ │ │ └── cassandra.go │ │ ├── clickhouse │ │ │ ├── 
clickhouse_test.go │ │ │ └── clickhouse.go │ │ ├── cloudmonitoring │ │ │ ├── cloud_monitoring_test.go │ │ │ └── cloud_monitoring.go │ │ ├── cloudsqladmin │ │ │ ├── cloud_sql_admin_test.go │ │ │ └── cloud_sql_admin.go │ │ ├── cloudsqlmssql │ │ │ ├── cloud_sql_mssql_test.go │ │ │ └── cloud_sql_mssql.go │ │ ├── cloudsqlmysql │ │ │ ├── cloud_sql_mysql_test.go │ │ │ └── cloud_sql_mysql.go │ │ ├── cloudsqlpg │ │ │ ├── cloud_sql_pg_test.go │ │ │ └── cloud_sql_pg.go │ │ ├── couchbase │ │ │ ├── couchbase_test.go │ │ │ └── couchbase.go │ │ ├── dataplex │ │ │ ├── dataplex_test.go │ │ │ └── dataplex.go │ │ ├── dgraph │ │ │ ├── dgraph_test.go │ │ │ └── dgraph.go │ │ ├── dialect.go │ │ ├── firebird │ │ │ ├── firebird_test.go │ │ │ └── firebird.go │ │ ├── firestore │ │ │ ├── firestore_test.go │ │ │ └── firestore.go │ │ ├── http │ │ │ ├── http_test.go │ │ │ └── http.go │ │ ├── ip_type.go │ │ ├── looker │ │ │ ├── looker_test.go │ │ │ └── looker.go │ │ ├── mongodb │ │ │ ├── mongodb_test.go │ │ │ └── mongodb.go │ │ ├── mssql │ │ │ ├── mssql_test.go │ │ │ └── mssql.go │ │ ├── mysql │ │ │ ├── mysql_test.go │ │ │ └── mysql.go │ │ ├── neo4j │ │ │ ├── neo4j_test.go │ │ │ └── neo4j.go │ │ ├── oceanbase │ │ │ ├── oceanbase_test.go │ │ │ └── oceanbase.go │ │ ├── oracle │ │ │ └── oracle.go │ │ ├── postgres │ │ │ ├── postgres_test.go │ │ │ └── postgres.go │ │ ├── redis │ │ │ ├── redis_test.go │ │ │ └── redis.go │ │ ├── sources.go │ │ ├── spanner │ │ │ ├── spanner_test.go │ │ │ └── spanner.go │ │ ├── sqlite │ │ │ ├── sqlite_test.go │ │ │ └── sqlite.go │ │ ├── tidb │ │ │ ├── tidb_test.go │ │ │ └── tidb.go │ │ ├── trino │ │ │ ├── trino_test.go │ │ │ └── trino.go │ │ ├── util.go │ │ ├── valkey │ │ │ ├── valkey_test.go │ │ │ └── valkey.go │ │ └── yugabytedb │ │ ├── yugabytedb_test.go │ │ └── yugabytedb.go │ ├── telemetry │ │ ├── instrumentation.go │ │ └── telemetry.go │ ├── testutils │ │ └── testutils.go │ ├── tools │ │ ├── alloydb │ │ │ ├── alloydbcreatecluster │ │ │ │ ├── alloydbcreatecluster_test.go │ │ │ │ └── alloydbcreatecluster.go │ │ │ ├── alloydbcreateinstance │ │ │ │ ├── alloydbcreateinstance_test.go │ │ │ │ └── alloydbcreateinstance.go │ │ │ ├── alloydbcreateuser │ │ │ │ ├── alloydbcreateuser_test.go │ │ │ │ └── alloydbcreateuser.go │ │ │ ├── alloydbgetcluster │ │ │ │ ├── alloydbgetcluster_test.go │ │ │ │ └── alloydbgetcluster.go │ │ │ ├── alloydbgetinstance │ │ │ │ ├── alloydbgetinstance_test.go │ │ │ │ └── alloydbgetinstance.go │ │ │ ├── alloydbgetuser │ │ │ │ ├── alloydbgetuser_test.go │ │ │ │ └── alloydbgetuser.go │ │ │ ├── alloydblistclusters │ │ │ │ ├── alloydblistclusters_test.go │ │ │ │ └── alloydblistclusters.go │ │ │ ├── alloydblistinstances │ │ │ │ ├── alloydblistinstances_test.go │ │ │ │ └── alloydblistinstances.go │ │ │ ├── alloydblistusers │ │ │ │ ├── alloydblistusers_test.go │ │ │ │ └── alloydblistusers.go │ │ │ └── alloydbwaitforoperation │ │ │ ├── alloydbwaitforoperation_test.go │ │ │ └── alloydbwaitforoperation.go │ │ ├── alloydbainl │ │ │ ├── alloydbainl_test.go │ │ │ └── alloydbainl.go │ │ ├── bigquery │ │ │ ├── bigqueryanalyzecontribution │ │ │ │ ├── bigqueryanalyzecontribution_test.go │ │ │ │ └── bigqueryanalyzecontribution.go │ │ │ ├── bigquerycommon │ │ │ │ ├── table_name_parser_test.go │ │ │ │ ├── table_name_parser.go │ │ │ │ └── util.go │ │ │ ├── bigqueryconversationalanalytics │ │ │ │ ├── bigqueryconversationalanalytics_test.go │ │ │ │ └── bigqueryconversationalanalytics.go │ │ │ ├── bigqueryexecutesql │ │ │ │ ├── bigqueryexecutesql_test.go │ │ │ │ └── bigqueryexecutesql.go │ │ │ ├── 
bigqueryforecast │ │ │ │ ├── bigqueryforecast_test.go │ │ │ │ └── bigqueryforecast.go │ │ │ ├── bigquerygetdatasetinfo │ │ │ │ ├── bigquerygetdatasetinfo_test.go │ │ │ │ └── bigquerygetdatasetinfo.go │ │ │ ├── bigquerygettableinfo │ │ │ │ ├── bigquerygettableinfo_test.go │ │ │ │ └── bigquerygettableinfo.go │ │ │ ├── bigquerylistdatasetids │ │ │ │ ├── bigquerylistdatasetids_test.go │ │ │ │ └── bigquerylistdatasetids.go │ │ │ ├── bigquerylisttableids │ │ │ │ ├── bigquerylisttableids_test.go │ │ │ │ └── bigquerylisttableids.go │ │ │ ├── bigquerysearchcatalog │ │ │ │ ├── bigquerysearchcatalog_test.go │ │ │ │ └── bigquerysearchcatalog.go │ │ │ └── bigquerysql │ │ │ ├── bigquerysql_test.go │ │ │ └── bigquerysql.go │ │ ├── bigtable │ │ │ ├── bigtable_test.go │ │ │ └── bigtable.go │ │ ├── cassandra │ │ │ └── cassandracql │ │ │ ├── cassandracql_test.go │ │ │ └── cassandracql.go │ │ ├── clickhouse │ │ │ ├── clickhouseexecutesql │ │ │ │ ├── clickhouseexecutesql_test.go │ │ │ │ └── clickhouseexecutesql.go │ │ │ ├── clickhouselistdatabases │ │ │ │ ├── clickhouselistdatabases_test.go │ │ │ │ └── clickhouselistdatabases.go │ │ │ ├── clickhouselisttables │ │ │ │ ├── clickhouselisttables_test.go │ │ │ │ └── clickhouselisttables.go │ │ │ └── clickhousesql │ │ │ ├── clickhousesql_test.go │ │ │ └── clickhousesql.go │ │ ├── cloudmonitoring │ │ │ ├── cloudmonitoring_test.go │ │ │ └── cloudmonitoring.go │ │ ├── cloudsql │ │ │ ├── cloudsqlcreatedatabase │ │ │ │ ├── cloudsqlcreatedatabase_test.go │ │ │ │ └── cloudsqlcreatedatabase.go │ │ │ ├── cloudsqlcreateusers │ │ │ │ ├── cloudsqlcreateusers_test.go │ │ │ │ └── cloudsqlcreateusers.go │ │ │ ├── cloudsqlgetinstances │ │ │ │ ├── cloudsqlgetinstances_test.go │ │ │ │ └── cloudsqlgetinstances.go │ │ │ ├── cloudsqllistdatabases │ │ │ │ ├── cloudsqllistdatabases_test.go │ │ │ │ └── cloudsqllistdatabases.go │ │ │ ├── cloudsqllistinstances │ │ │ │ ├── cloudsqllistinstances_test.go │ │ │ │ └── cloudsqllistinstances.go │ │ │ └── cloudsqlwaitforoperation │ │ │ ├── cloudsqlwaitforoperation_test.go │ │ │ └── cloudsqlwaitforoperation.go │ │ ├── cloudsqlmssql │ │ │ └── cloudsqlmssqlcreateinstance │ │ │ ├── cloudsqlmssqlcreateinstance_test.go │ │ │ └── cloudsqlmssqlcreateinstance.go │ │ ├── cloudsqlmysql │ │ │ └── cloudsqlmysqlcreateinstance │ │ │ ├── cloudsqlmysqlcreateinstance_test.go │ │ │ └── cloudsqlmysqlcreateinstance.go │ │ ├── cloudsqlpg │ │ │ └── cloudsqlpgcreateinstances │ │ │ ├── cloudsqlpgcreateinstances_test.go │ │ │ └── cloudsqlpgcreateinstances.go │ │ ├── common_test.go │ │ ├── common.go │ │ ├── couchbase │ │ │ ├── couchbase_test.go │ │ │ └── couchbase.go │ │ ├── dataform │ │ │ └── dataformcompilelocal │ │ │ ├── dataformcompilelocal_test.go │ │ │ └── dataformcompilelocal.go │ │ ├── dataplex │ │ │ ├── dataplexlookupentry │ │ │ │ ├── dataplexlookupentry_test.go │ │ │ │ └── dataplexlookupentry.go │ │ │ ├── dataplexsearchaspecttypes │ │ │ │ ├── dataplexsearchaspecttypes_test.go │ │ │ │ └── dataplexsearchaspecttypes.go │ │ │ └── dataplexsearchentries │ │ │ ├── dataplexsearchentries_test.go │ │ │ └── dataplexsearchentries.go │ │ ├── dgraph │ │ │ ├── dgraph_test.go │ │ │ └── dgraph.go │ │ ├── firebird │ │ │ ├── firebirdexecutesql │ │ │ │ ├── firebirdexecutesql_test.go │ │ │ │ └── firebirdexecutesql.go │ │ │ └── firebirdsql │ │ │ ├── firebirdsql_test.go │ │ │ └── firebirdsql.go │ │ ├── firestore │ │ │ ├── firestoreadddocuments │ │ │ │ ├── firestoreadddocuments_test.go │ │ │ │ └── firestoreadddocuments.go │ │ │ ├── firestoredeletedocuments │ │ │ │ ├── 
firestoredeletedocuments_test.go │ │ │ │ └── firestoredeletedocuments.go │ │ │ ├── firestoregetdocuments │ │ │ │ ├── firestoregetdocuments_test.go │ │ │ │ └── firestoregetdocuments.go │ │ │ ├── firestoregetrules │ │ │ │ ├── firestoregetrules_test.go │ │ │ │ └── firestoregetrules.go │ │ │ ├── firestorelistcollections │ │ │ │ ├── firestorelistcollections_test.go │ │ │ │ └── firestorelistcollections.go │ │ │ ├── firestorequery │ │ │ │ ├── firestorequery_test.go │ │ │ │ └── firestorequery.go │ │ │ ├── firestorequerycollection │ │ │ │ ├── firestorequerycollection_test.go │ │ │ │ └── firestorequerycollection.go │ │ │ ├── firestoreupdatedocument │ │ │ │ ├── firestoreupdatedocument_test.go │ │ │ │ └── firestoreupdatedocument.go │ │ │ ├── firestorevalidaterules │ │ │ │ ├── firestorevalidaterules_test.go │ │ │ │ └── firestorevalidaterules.go │ │ │ └── util │ │ │ ├── converter_test.go │ │ │ ├── converter.go │ │ │ ├── validator_test.go │ │ │ └── validator.go │ │ ├── http │ │ │ ├── http_test.go │ │ │ └── http.go │ │ ├── http_method.go │ │ ├── looker │ │ │ ├── lookeradddashboardelement │ │ │ │ ├── lookeradddashboardelement_test.go │ │ │ │ └── lookeradddashboardelement.go │ │ │ ├── lookercommon │ │ │ │ ├── lookercommon_test.go │ │ │ │ └── lookercommon.go │ │ │ ├── lookerconversationalanalytics │ │ │ │ ├── lookerconversationalanalytics_test.go │ │ │ │ └── lookerconversationalanalytics.go │ │ │ ├── lookergetdashboards │ │ │ │ ├── lookergetdashboards_test.go │ │ │ │ └── lookergetdashboards.go │ │ │ ├── lookergetdimensions │ │ │ │ ├── lookergetdimensions_test.go │ │ │ │ └── lookergetdimensions.go │ │ │ ├── lookergetexplores │ │ │ │ ├── lookergetexplores_test.go │ │ │ │ └── lookergetexplores.go │ │ │ ├── lookergetfilters │ │ │ │ ├── lookergetfilters_test.go │ │ │ │ └── lookergetfilters.go │ │ │ ├── lookergetlooks │ │ │ │ ├── lookergetlooks_test.go │ │ │ │ └── lookergetlooks.go │ │ │ ├── lookergetmeasures │ │ │ │ ├── lookergetmeasures_test.go │ │ │ │ └── lookergetmeasures.go │ │ │ ├── lookergetmodels │ │ │ │ ├── lookergetmodels_test.go │ │ │ │ └── lookergetmodels.go │ │ │ ├── lookergetparameters │ │ │ │ ├── lookergetparameters_test.go │ │ │ │ └── lookergetparameters.go │ │ │ ├── lookerhealthanalyze │ │ │ │ ├── lookerhealthanalyze_test.go │ │ │ │ └── lookerhealthanalyze.go │ │ │ ├── lookerhealthpulse │ │ │ │ ├── lookerhealthpulse_test.go │ │ │ │ └── lookerhealthpulse.go │ │ │ ├── lookerhealthvacuum │ │ │ │ ├── lookerhealthvacuum_test.go │ │ │ │ └── lookerhealthvacuum.go │ │ │ ├── lookermakedashboard │ │ │ │ ├── lookermakedashboard_test.go │ │ │ │ └── lookermakedashboard.go │ │ │ ├── lookermakelook │ │ │ │ ├── lookermakelook_test.go │ │ │ │ └── lookermakelook.go │ │ │ ├── lookerquery │ │ │ │ ├── lookerquery_test.go │ │ │ │ └── lookerquery.go │ │ │ ├── lookerquerysql │ │ │ │ ├── lookerquerysql_test.go │ │ │ │ └── lookerquerysql.go │ │ │ ├── lookerqueryurl │ │ │ │ ├── lookerqueryurl_test.go │ │ │ │ └── lookerqueryurl.go │ │ │ └── lookerrunlook │ │ │ ├── lookerrunlook_test.go │ │ │ └── lookerrunlook.go │ │ ├── mongodb │ │ │ ├── mongodbaggregate │ │ │ │ ├── mongodbaggregate_test.go │ │ │ │ └── mongodbaggregate.go │ │ │ ├── mongodbdeletemany │ │ │ │ ├── mongodbdeletemany_test.go │ │ │ │ └── mongodbdeletemany.go │ │ │ ├── mongodbdeleteone │ │ │ │ ├── mongodbdeleteone_test.go │ │ │ │ └── mongodbdeleteone.go │ │ │ ├── mongodbfind │ │ │ │ ├── mongodbfind_test.go │ │ │ │ └── mongodbfind.go │ │ │ ├── mongodbfindone │ │ │ │ ├── mongodbfindone_test.go │ │ │ │ └── mongodbfindone.go │ │ │ ├── mongodbinsertmany │ │ │ │ ├── 
mongodbinsertmany_test.go │ │ │ │ └── mongodbinsertmany.go │ │ │ ├── mongodbinsertone │ │ │ │ ├── mongodbinsertone_test.go │ │ │ │ └── mongodbinsertone.go │ │ │ ├── mongodbupdatemany │ │ │ │ ├── mongodbupdatemany_test.go │ │ │ │ └── mongodbupdatemany.go │ │ │ └── mongodbupdateone │ │ │ ├── mongodbupdateone_test.go │ │ │ └── mongodbupdateone.go │ │ ├── mssql │ │ │ ├── mssqlexecutesql │ │ │ │ ├── mssqlexecutesql_test.go │ │ │ │ └── mssqlexecutesql.go │ │ │ ├── mssqllisttables │ │ │ │ ├── mssqllisttables_test.go │ │ │ │ └── mssqllisttables.go │ │ │ └── mssqlsql │ │ │ ├── mssqlsql_test.go │ │ │ └── mssqlsql.go │ │ ├── mysql │ │ │ ├── mysqlcommon │ │ │ │ └── mysqlcommon.go │ │ │ ├── mysqlexecutesql │ │ │ │ ├── mysqlexecutesql_test.go │ │ │ │ └── mysqlexecutesql.go │ │ │ ├── mysqllistactivequeries │ │ │ │ ├── mysqllistactivequeries_test.go │ │ │ │ └── mysqllistactivequeries.go │ │ │ ├── mysqllisttablefragmentation │ │ │ │ ├── mysqllisttablefragmentation_test.go │ │ │ │ └── mysqllisttablefragmentation.go │ │ │ ├── mysqllisttables │ │ │ │ ├── mysqllisttables_test.go │ │ │ │ └── mysqllisttables.go │ │ │ ├── mysqllisttablesmissinguniqueindexes │ │ │ │ ├── mysqllisttablesmissinguniqueindexes_test.go │ │ │ │ └── mysqllisttablesmissinguniqueindexes.go │ │ │ └── mysqlsql │ │ │ ├── mysqlsql_test.go │ │ │ └── mysqlsql.go │ │ ├── neo4j │ │ │ ├── neo4jcypher │ │ │ │ ├── neo4jcypher_test.go │ │ │ │ └── neo4jcypher.go │ │ │ ├── neo4jexecutecypher │ │ │ │ ├── classifier │ │ │ │ │ ├── classifier_test.go │ │ │ │ │ └── classifier.go │ │ │ │ ├── neo4jexecutecypher_test.go │ │ │ │ └── neo4jexecutecypher.go │ │ │ └── neo4jschema │ │ │ ├── cache │ │ │ │ ├── cache_test.go │ │ │ │ └── cache.go │ │ │ ├── helpers │ │ │ │ ├── helpers_test.go │ │ │ │ └── helpers.go │ │ │ ├── neo4jschema_test.go │ │ │ ├── neo4jschema.go │ │ │ └── types │ │ │ └── types.go │ │ ├── oceanbase │ │ │ ├── oceanbaseexecutesql │ │ │ │ ├── oceanbaseexecutesql_test.go │ │ │ │ └── oceanbaseexecutesql.go │ │ │ └── oceanbasesql │ │ │ ├── oceanbasesql_test.go │ │ │ └── oceanbasesql.go │ │ ├── oracle │ │ │ ├── oracleexecutesql │ │ │ │ └── oracleexecutesql.go │ │ │ └── oraclesql │ │ │ └── oraclesql.go │ │ ├── parameters_test.go │ │ ├── parameters.go │ │ ├── postgres │ │ │ ├── postgresexecutesql │ │ │ │ ├── postgresexecutesql_test.go │ │ │ │ └── postgresexecutesql.go │ │ │ ├── postgreslistactivequeries │ │ │ │ ├── postgreslistactivequeries_test.go │ │ │ │ └── postgreslistactivequeries.go │ │ │ ├── postgreslistavailableextensions │ │ │ │ ├── postgreslistavailableextensions_test.go │ │ │ │ └── postgreslistavailableextensions.go │ │ │ ├── postgreslistinstalledextensions │ │ │ │ ├── postgreslistinstalledextensions_test.go │ │ │ │ └── postgreslistinstalledextensions.go │ │ │ ├── postgreslisttables │ │ │ │ ├── postgreslisttables_test.go │ │ │ │ └── postgreslisttables.go │ │ │ └── postgressql │ │ │ ├── postgressql_test.go │ │ │ └── postgressql.go │ │ ├── redis │ │ │ ├── redis_test.go │ │ │ └── redis.go │ │ ├── spanner │ │ │ ├── spannerexecutesql │ │ │ │ ├── spannerexecutesql_test.go │ │ │ │ └── spannerexecutesql.go │ │ │ ├── spannerlisttables │ │ │ │ ├── spannerlisttables_test.go │ │ │ │ └── spannerlisttables.go │ │ │ └── spannersql │ │ │ ├── spanner_test.go │ │ │ └── spannersql.go │ │ ├── sqlite │ │ │ ├── sqliteexecutesql │ │ │ │ ├── sqliteexecutesql_test.go │ │ │ │ └── sqliteexecutesql.go │ │ │ └── sqlitesql │ │ │ ├── sqlitesql_test.go │ │ │ └── sqlitesql.go │ │ ├── tidb │ │ │ ├── tidbexecutesql │ │ │ │ ├── tidbexecutesql_test.go │ │ │ │ └── tidbexecutesql.go │ │ 
│ └── tidbsql │ │ │ ├── tidbsql_test.go │ │ │ └── tidbsql.go │ │ ├── tools_test.go │ │ ├── tools.go │ │ ├── toolsets.go │ │ ├── trino │ │ │ ├── trinoexecutesql │ │ │ │ ├── trinoexecutesql_test.go │ │ │ │ └── trinoexecutesql.go │ │ │ └── trinosql │ │ │ ├── trinosql_test.go │ │ │ └── trinosql.go │ │ ├── utility │ │ │ └── wait │ │ │ ├── wait_test.go │ │ │ └── wait.go │ │ ├── valkey │ │ │ ├── valkey_test.go │ │ │ └── valkey.go │ │ └── yugabytedbsql │ │ ├── yugabytedbsql_test.go │ │ └── yugabytedbsql.go │ └── util │ └── util.go ├── LICENSE ├── logo.png ├── main.go ├── README.md └── tests ├── alloydb │ ├── alloydb_integration_test.go │ └── alloydb_wait_for_operation_test.go ├── alloydbainl │ └── alloydb_ai_nl_integration_test.go ├── alloydbpg │ └── alloydb_pg_integration_test.go ├── auth.go ├── bigquery │ └── bigquery_integration_test.go ├── bigtable │ └── bigtable_integration_test.go ├── cassandra │ └── cassandra_integration_test.go ├── clickhouse │ └── clickhouse_integration_test.go ├── cloudmonitoring │ └── cloud_monitoring_integration_test.go ├── cloudsql │ ├── cloud_sql_create_database_test.go │ ├── cloud_sql_create_users_test.go │ ├── cloud_sql_get_instances_test.go │ ├── cloud_sql_list_databases_test.go │ ├── cloudsql_list_instances_test.go │ └── cloudsql_wait_for_operation_test.go ├── cloudsqlmssql │ ├── cloud_sql_mssql_create_instance_integration_test.go │ └── cloud_sql_mssql_integration_test.go ├── cloudsqlmysql │ ├── cloud_sql_mysql_create_instance_integration_test.go │ └── cloud_sql_mysql_integration_test.go ├── cloudsqlpg │ ├── cloud_sql_pg_create_instances_test.go │ └── cloud_sql_pg_integration_test.go ├── common.go ├── couchbase │ └── couchbase_integration_test.go ├── dataform │ └── dataform_integration_test.go ├── dataplex │ └── dataplex_integration_test.go ├── dgraph │ └── dgraph_integration_test.go ├── firebird │ └── firebird_integration_test.go ├── firestore │ └── firestore_integration_test.go ├── http │ └── http_integration_test.go ├── looker │ └── looker_integration_test.go ├── mongodb │ └── mongodb_integration_test.go ├── mssql │ └── mssql_integration_test.go ├── mysql │ └── mysql_integration_test.go ├── neo4j │ └── neo4j_integration_test.go ├── oceanbase │ └── oceanbase_integration_test.go ├── option.go ├── oracle │ └── oracle_integration_test.go ├── postgres │ └── postgres_integration_test.go ├── redis │ └── redis_test.go ├── server.go ├── source.go ├── spanner │ └── spanner_integration_test.go ├── sqlite │ └── sqlite_integration_test.go ├── tidb │ └── tidb_integration_test.go ├── tool.go ├── trino │ └── trino_integration_test.go ├── utility │ └── wait_integration_test.go ├── valkey │ └── valkey_test.go └── yugabytedb └── yugabytedb_integration_test.go ``` # Files -------------------------------------------------------------------------------- /tests/http/http_integration_test.go: -------------------------------------------------------------------------------- ```go 1 | // Copyright 2025 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package http 16 | 17 | import ( 18 | "bytes" 19 | "context" 20 | "encoding/json" 21 | "fmt" 22 | "io" 23 | "net/http" 24 | "net/http/httptest" 25 | "reflect" 26 | "regexp" 27 | "strings" 28 | "testing" 29 | "time" 30 | 31 | "github.com/googleapis/genai-toolbox/internal/testutils" 32 | "github.com/googleapis/genai-toolbox/internal/tools" 33 | "github.com/googleapis/genai-toolbox/tests" 34 | ) 35 | 36 | var ( 37 | HttpSourceKind = "http" 38 | HttpToolKind = "http" 39 | ) 40 | 41 | func getHTTPSourceConfig(t *testing.T) map[string]any { 42 | idToken, err := tests.GetGoogleIdToken(tests.ClientId) 43 | if err != nil { 44 | t.Fatalf("error getting ID token: %s", err) 45 | } 46 | idToken = "Bearer " + idToken 47 | return map[string]any{ 48 | "kind": HttpSourceKind, 49 | "headers": map[string]string{"Authorization": idToken}, 50 | } 51 | } 52 | 53 | // handler function for the test server 54 | func multiTool(w http.ResponseWriter, r *http.Request) { 55 | path := r.URL.Path 56 | path = strings.TrimPrefix(path, "/") // Remove leading slash 57 | 58 | switch path { 59 | case "tool0": 60 | handleTool0(w, r) 61 | case "tool1": 62 | handleTool1(w, r) 63 | case "tool1id": 64 | handleTool1Id(w, r) 65 | case "tool1name": 66 | handleTool1Name(w, r) 67 | case "tool2": 68 | handleTool2(w, r) 69 | case "tool3": 70 | handleTool3(w, r) 71 | default: 72 | http.NotFound(w, r) // Return 404 for unknown paths 73 | } 74 | } 75 | 76 | // handler function for the test server 77 | func handleTool0(w http.ResponseWriter, r *http.Request) { 78 | // expect POST method 79 | if r.Method != http.MethodPost { 80 | errorMessage := fmt.Sprintf("expected POST method but got: %s", string(r.Method)) 81 | http.Error(w, errorMessage, http.StatusBadRequest) 82 | return 83 | } 84 | w.WriteHeader(http.StatusOK) 85 | response := "hello world" 86 | err := json.NewEncoder(w).Encode(response) 87 | if err != nil { 88 | http.Error(w, "Failed to encode JSON", http.StatusInternalServerError) 89 | return 90 | } 91 | } 92 | 93 | // handler function for the test server 94 | func handleTool1(w http.ResponseWriter, r *http.Request) { 95 | // expect GET method 96 | if r.Method != http.MethodGet { 97 | errorMessage := fmt.Sprintf("expected GET method but got: %s", string(r.Method)) 98 | http.Error(w, errorMessage, http.StatusBadRequest) 99 | return 100 | } 101 | // Parse request body 102 | var requestBody map[string]interface{} 103 | bodyBytes, readErr := io.ReadAll(r.Body) 104 | if readErr != nil { 105 | http.Error(w, "Bad Request: Failed to read request body", http.StatusBadRequest) 106 | return 107 | } 108 | defer r.Body.Close() 109 | err := json.Unmarshal(bodyBytes, &requestBody) 110 | if err != nil { 111 | errorMessage := fmt.Sprintf("Bad Request: Error unmarshalling request body: %s, Raw body: %s", err, string(bodyBytes)) 112 | http.Error(w, errorMessage, http.StatusBadRequest) 113 | return 114 | } 115 | 116 | // Extract name 117 | name, ok := requestBody["name"].(string) 118 | if !ok || name == "" { 119 | http.Error(w, "Bad Request: Missing or invalid name", http.StatusBadRequest) 120 | return 121 | } 122 | 123 | if name == "Alice" { 124 | response := `[{"id":1,"name":"Alice"},{"id":3,"name":"Sid"}]` 125 | _, err := w.Write([]byte(response)) 126 | if err != nil { 127 | http.Error(w, "Failed to write response", http.StatusInternalServerError) 128 | } 129 | return 130 | } 131 | 132 | http.Error(w, "Method not allowed", 
http.StatusMethodNotAllowed) 133 | } 134 | 135 | // handler function for the test server 136 | func handleTool1Id(w http.ResponseWriter, r *http.Request) { 137 | // expect GET method 138 | if r.Method != http.MethodGet { 139 | errorMessage := fmt.Sprintf("expected GET method but got: %s", string(r.Method)) 140 | http.Error(w, errorMessage, http.StatusBadRequest) 141 | return 142 | } 143 | 144 | id := r.URL.Query().Get("id") 145 | if id == "4" { 146 | response := `[{"id":4,"name":null}]` 147 | _, err := w.Write([]byte(response)) 148 | if err != nil { 149 | http.Error(w, "Failed to write response", http.StatusInternalServerError) 150 | } 151 | return 152 | } 153 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 154 | } 155 | 156 | // handler function for the test server 157 | func handleTool1Name(w http.ResponseWriter, r *http.Request) { 158 | // expect GET method 159 | if r.Method != http.MethodGet { 160 | errorMessage := fmt.Sprintf("expected GET method but got: %s", string(r.Method)) 161 | http.Error(w, errorMessage, http.StatusBadRequest) 162 | return 163 | } 164 | 165 | name := r.URL.Query().Get("name") 166 | if name == "" { 167 | response := "null" 168 | _, err := w.Write([]byte(response)) 169 | if err != nil { 170 | http.Error(w, "Failed to write response", http.StatusInternalServerError) 171 | } 172 | return 173 | } 174 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 175 | } 176 | 177 | // handler function for the test server 178 | func handleTool2(w http.ResponseWriter, r *http.Request) { 179 | // expect GET method 180 | if r.Method != http.MethodGet { 181 | errorMessage := fmt.Sprintf("expected GET method but got: %s", string(r.Method)) 182 | http.Error(w, errorMessage, http.StatusBadRequest) 183 | return 184 | } 185 | email := r.URL.Query().Get("email") 186 | if email != "" { 187 | response := `[{"name":"Alice"}]` 188 | _, err := w.Write([]byte(response)) 189 | if err != nil { 190 | http.Error(w, "Failed to write response", http.StatusInternalServerError) 191 | } 192 | return 193 | } 194 | 195 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 196 | } 197 | 198 | // handler function for the test server 199 | func handleTool3(w http.ResponseWriter, r *http.Request) { 200 | // expect GET method 201 | if r.Method != http.MethodGet { 202 | errorMessage := fmt.Sprintf("expected GET method but got: %s", string(r.Method)) 203 | http.Error(w, errorMessage, http.StatusBadRequest) 204 | return 205 | } 206 | 207 | // Check request headers 208 | expectedHeaders := map[string]string{ 209 | "Content-Type": "application/json", 210 | "X-Custom-Header": "example", 211 | "X-Other-Header": "test", 212 | } 213 | for header, expectedValue := range expectedHeaders { 214 | if r.Header.Get(header) != expectedValue { 215 | errorMessage := fmt.Sprintf("Bad Request: Missing or incorrect header: %s", header) 216 | http.Error(w, errorMessage, http.StatusBadRequest) 217 | return 218 | } 219 | } 220 | 221 | // Check query parameters 222 | expectedQueryParams := map[string][]string{ 223 | "id": []string{"2", "1", "3"}, 224 | "country": []string{"US"}, 225 | } 226 | query := r.URL.Query() 227 | for param, expectedValueSlice := range expectedQueryParams { 228 | values, ok := query[param] 229 | if ok { 230 | if !reflect.DeepEqual(expectedValueSlice, values) { 231 | errorMessage := fmt.Sprintf("Bad Request: Incorrect query parameter: %s, actual: %s", param, query[param]) 232 | http.Error(w, errorMessage, http.StatusBadRequest) 233 | return 234 | } 235 | } else { 
236 | errorMessage := fmt.Sprintf("Bad Request: Missing query parameter: %s, actual: %s", param, query[param]) 237 | http.Error(w, errorMessage, http.StatusBadRequest) 238 | return 239 | } 240 | } 241 | 242 | // Parse request body 243 | var requestBody map[string]interface{} 244 | bodyBytes, readErr := io.ReadAll(r.Body) 245 | if readErr != nil { 246 | http.Error(w, "Bad Request: Failed to read request body", http.StatusBadRequest) 247 | return 248 | } 249 | defer r.Body.Close() 250 | err := json.Unmarshal(bodyBytes, &requestBody) 251 | if err != nil { 252 | errorMessage := fmt.Sprintf("Bad Request: Error unmarshalling request body: %s, Raw body: %s", err, string(bodyBytes)) 253 | http.Error(w, errorMessage, http.StatusBadRequest) 254 | return 255 | } 256 | 257 | // Check request body 258 | expectedBody := map[string]interface{}{ 259 | "place": "zoo", 260 | "animals": []any{"rabbit", "ostrich", "whale"}, 261 | } 262 | 263 | if !reflect.DeepEqual(requestBody, expectedBody) { 264 | errorMessage := fmt.Sprintf("Bad Request: Incorrect request body. Expected: %v, Got: %v", expectedBody, requestBody) 265 | http.Error(w, errorMessage, http.StatusBadRequest) 266 | return 267 | } 268 | 269 | response := "hello world" 270 | err = json.NewEncoder(w).Encode(response) 271 | if err != nil { 272 | http.Error(w, "Failed to encode JSON", http.StatusInternalServerError) 273 | return 274 | } 275 | } 276 | 277 | func TestHttpToolEndpoints(t *testing.T) { 278 | // start a test server 279 | server := httptest.NewServer(http.HandlerFunc(multiTool)) 280 | defer server.Close() 281 | 282 | sourceConfig := getHTTPSourceConfig(t) 283 | sourceConfig["baseUrl"] = server.URL 284 | ctx, cancel := context.WithTimeout(context.Background(), time.Minute) 285 | defer cancel() 286 | 287 | var args []string 288 | 289 | toolsFile := getHTTPToolsConfig(sourceConfig, HttpToolKind) 290 | cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) 
291 | if err != nil { 292 | t.Fatalf("command initialization returned an error: %s", err) 293 | } 294 | defer cleanup() 295 | 296 | waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) 297 | defer cancel() 298 | out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) 299 | if err != nil { 300 | t.Logf("toolbox command logs: \n%s", out) 301 | t.Fatalf("toolbox didn't start successfully: %s", err) 302 | } 303 | 304 | // Run tests 305 | tests.RunToolGetTest(t) 306 | tests.RunToolInvokeTest(t, `"hello world"`, tests.DisableArrayTest()) 307 | runAdvancedHTTPInvokeTest(t) 308 | } 309 | 310 | // runToolInvoke runs the tool invoke endpoint 311 | func runAdvancedHTTPInvokeTest(t *testing.T) { 312 | // Test HTTP tool invoke endpoint 313 | invokeTcs := []struct { 314 | name string 315 | api string 316 | requestHeader map[string]string 317 | requestBody io.Reader 318 | want string 319 | isErr bool 320 | }{ 321 | { 322 | name: "invoke my-advanced-tool", 323 | api: "http://127.0.0.1:5000/api/tool/my-advanced-tool/invoke", 324 | requestHeader: map[string]string{}, 325 | requestBody: bytes.NewBuffer([]byte(`{"animalArray": ["rabbit", "ostrich", "whale"], "id": 3, "path": "tool3", "country": "US", "X-Other-Header": "test"}`)), 326 | want: `"hello world"`, 327 | isErr: false, 328 | }, 329 | { 330 | name: "invoke my-advanced-tool with wrong params", 331 | api: "http://127.0.0.1:5000/api/tool/my-advanced-tool/invoke", 332 | requestHeader: map[string]string{}, 333 | requestBody: bytes.NewBuffer([]byte(`{"animalArray": ["rabbit", "ostrich", "whale"], "id": 4, "path": "tool3", "country": "US", "X-Other-Header": "test"}`)), 334 | isErr: true, 335 | }, 336 | } 337 | for _, tc := range invokeTcs { 338 | t.Run(tc.name, func(t *testing.T) { 339 | // Send Tool invocation request 340 | req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody) 341 | if err != nil { 342 | t.Fatalf("unable to create request: %s", err) 343 | } 344 | req.Header.Add("Content-type", "application/json") 345 | for k, v := range tc.requestHeader { 346 | req.Header.Add(k, v) 347 | } 348 | resp, err := http.DefaultClient.Do(req) 349 | if err != nil { 350 | t.Fatalf("unable to send request: %s", err) 351 | } 352 | defer resp.Body.Close() 353 | 354 | if resp.StatusCode != http.StatusOK { 355 | if tc.isErr == true { 356 | return 357 | } 358 | bodyBytes, _ := io.ReadAll(resp.Body) 359 | t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) 360 | } 361 | 362 | // Check response body 363 | var body map[string]interface{} 364 | err = json.NewDecoder(resp.Body).Decode(&body) 365 | if err != nil { 366 | t.Fatalf("error parsing response body") 367 | } 368 | got, ok := body["result"].(string) 369 | if !ok { 370 | t.Fatalf("unable to find result in response body") 371 | } 372 | 373 | if got != tc.want { 374 | t.Fatalf("unexpected value: got %q, want %q", got, tc.want) 375 | } 376 | }) 377 | } 378 | } 379 | 380 | // getHTTPToolsConfig returns a mock HTTP tool's config file 381 | func getHTTPToolsConfig(sourceConfig map[string]any, toolKind string) map[string]any { 382 | // Write config into a file and pass it to command 383 | otherSourceConfig := make(map[string]any) 384 | for k, v := range sourceConfig { 385 | otherSourceConfig[k] = v 386 | } 387 | otherSourceConfig["headers"] = map[string]string{"X-Custom-Header": "unexpected", "Content-Type": "application/json"} 388 | otherSourceConfig["queryParams"] = map[string]any{"id": 1, "name": "Sid"} 389 | 390 | 
toolsFile := map[string]any{ 391 | "sources": map[string]any{ 392 | "my-instance": sourceConfig, 393 | "other-instance": otherSourceConfig, 394 | }, 395 | "authServices": map[string]any{ 396 | "my-google-auth": map[string]any{ 397 | "kind": "google", 398 | "clientId": tests.ClientId, 399 | }, 400 | }, 401 | "tools": map[string]any{ 402 | "my-simple-tool": map[string]any{ 403 | "kind": toolKind, 404 | "path": "/tool0", 405 | "method": "POST", 406 | "source": "my-instance", 407 | "requestBody": "{}", 408 | "description": "Simple tool to test end to end functionality.", 409 | }, 410 | "my-tool": map[string]any{ 411 | "kind": toolKind, 412 | "source": "my-instance", 413 | "method": "GET", 414 | "path": "/tool1", 415 | "description": "some description", 416 | "queryParams": []tools.Parameter{ 417 | tools.NewIntParameter("id", "user ID")}, 418 | "requestBody": `{ 419 | "age": 36, 420 | "name": "{{.name}}" 421 | } 422 | `, 423 | "bodyParams": []tools.Parameter{tools.NewStringParameter("name", "user name")}, 424 | "headers": map[string]string{"Content-Type": "application/json"}, 425 | }, 426 | "my-tool-by-id": map[string]any{ 427 | "kind": toolKind, 428 | "source": "my-instance", 429 | "method": "GET", 430 | "path": "/tool1id", 431 | "description": "some description", 432 | "queryParams": []tools.Parameter{ 433 | tools.NewIntParameter("id", "user ID")}, 434 | "headers": map[string]string{"Content-Type": "application/json"}, 435 | }, 436 | "my-tool-by-name": map[string]any{ 437 | "kind": toolKind, 438 | "source": "my-instance", 439 | "method": "GET", 440 | "path": "/tool1name", 441 | "description": "some description", 442 | "queryParams": []tools.Parameter{ 443 | tools.NewStringParameterWithRequired("name", "user name", false)}, 444 | "headers": map[string]string{"Content-Type": "application/json"}, 445 | }, 446 | "my-auth-tool": map[string]any{ 447 | "kind": toolKind, 448 | "source": "my-instance", 449 | "method": "GET", 450 | "path": "/tool2", 451 | "description": "some description", 452 | "requestBody": "{}", 453 | "queryParams": []tools.Parameter{ 454 | tools.NewStringParameterWithAuth("email", "some description", 455 | []tools.ParamAuthService{{Name: "my-google-auth", Field: "email"}}), 456 | }, 457 | }, 458 | "my-auth-required-tool": map[string]any{ 459 | "kind": toolKind, 460 | "source": "my-instance", 461 | "method": "POST", 462 | "path": "/tool0", 463 | "description": "some description", 464 | "requestBody": "{}", 465 | "authRequired": []string{"my-google-auth"}, 466 | }, 467 | "my-advanced-tool": map[string]any{ 468 | "kind": toolKind, 469 | "source": "other-instance", 470 | "method": "get", 471 | "path": "/{{.path}}?id=2", 472 | "description": "some description", 473 | "headers": map[string]string{ 474 | "X-Custom-Header": "example", 475 | }, 476 | "pathParams": []tools.Parameter{ 477 | &tools.StringParameter{ 478 | CommonParameter: tools.CommonParameter{Name: "path", Type: "string", Desc: "path param"}, 479 | }, 480 | }, 481 | "queryParams": []tools.Parameter{ 482 | tools.NewIntParameter("id", "user ID"), tools.NewStringParameter("country", "country")}, 483 | "requestBody": `{ 484 | "place": "zoo", 485 | "animals": {{json .animalArray }} 486 | } 487 | `, 488 | "bodyParams": []tools.Parameter{tools.NewArrayParameter("animalArray", "animals in the zoo", tools.NewStringParameter("animals", "desc"))}, 489 | "headerParams": []tools.Parameter{tools.NewStringParameter("X-Other-Header", "custom header")}, 490 | }, 491 | }, 492 | } 493 | return toolsFile 494 | } 495 | ``` 
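The `getHTTPToolsConfig` helper above builds the tool configuration as a Go map before it is written out for `tests.StartCmd`. For readers who want to see the same configuration in the file form the server normally consumes, below is a rough, hand-written YAML sketch of the `other-instance` source and the `my-advanced-tool` entry. The field names are taken directly from the map above; the exact parameter encoding (the `name`/`type`/`description` keys and `items` for the array parameter) is an assumption based on the project's documented HTTP tool configuration, not something shown on this page.

```yaml
# Hypothetical tools.yaml sketch mirroring the maps built in getHTTPToolsConfig.
# The parameter encoding (name/type/description, items) is assumed, not taken from this page.
sources:
  other-instance:
    kind: http
    baseUrl: "<httptest server URL>"   # the test injects server.URL here
    headers:
      X-Custom-Header: unexpected      # overridden by the tool-level header below
      Content-Type: application/json
    queryParams:
      id: 1
      name: Sid
tools:
  my-advanced-tool:
    kind: http
    source: other-instance
    method: get
    path: /{{.path}}?id=2
    description: some description
    headers:
      X-Custom-Header: example
    pathParams:
      - name: path
        type: string
        description: path param
    queryParams:
      - name: id
        type: integer
        description: user ID
      - name: country
        type: string
        description: country
    requestBody: |
      {
        "place": "zoo",
        "animals": {{json .animalArray}}
      }
    bodyParams:
      - name: animalArray
        type: array
        description: animals in the zoo
        items:
          name: animals
          type: string
          description: desc
    headerParams:
      - name: X-Other-Header
        type: string
        description: custom header
```

Either form exercises the same code path in the tool: path, query, header, and body parameters are rendered into the templated `path` and `requestBody` before the request is sent to the mock server started in `TestHttpToolEndpoints`.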
-------------------------------------------------------------------------------- /internal/server/mcp.go: -------------------------------------------------------------------------------- ```go 1 | // Copyright 2025 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package server 16 | 17 | import ( 18 | "bufio" 19 | "bytes" 20 | "context" 21 | "encoding/json" 22 | "errors" 23 | "fmt" 24 | "io" 25 | "net/http" 26 | "strings" 27 | "sync" 28 | "time" 29 | 30 | "github.com/go-chi/chi/v5" 31 | "github.com/go-chi/chi/v5/middleware" 32 | "github.com/go-chi/render" 33 | "github.com/google/uuid" 34 | "github.com/googleapis/genai-toolbox/internal/server/mcp" 35 | "github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc" 36 | mcputil "github.com/googleapis/genai-toolbox/internal/server/mcp/util" 37 | v20241105 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20241105" 38 | v20250326 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20250326" 39 | "github.com/googleapis/genai-toolbox/internal/tools" 40 | "github.com/googleapis/genai-toolbox/internal/util" 41 | "go.opentelemetry.io/otel/attribute" 42 | "go.opentelemetry.io/otel/codes" 43 | "go.opentelemetry.io/otel/metric" 44 | ) 45 | 46 | type sseSession struct { 47 | writer http.ResponseWriter 48 | flusher http.Flusher 49 | done chan struct{} 50 | eventQueue chan string 51 | lastActive time.Time 52 | } 53 | 54 | // sseManager manages and control access to sse sessions 55 | type sseManager struct { 56 | mu sync.Mutex 57 | sseSessions map[string]*sseSession 58 | } 59 | 60 | func (m *sseManager) get(id string) (*sseSession, bool) { 61 | m.mu.Lock() 62 | defer m.mu.Unlock() 63 | session, ok := m.sseSessions[id] 64 | session.lastActive = time.Now() 65 | return session, ok 66 | } 67 | 68 | func newSseManager(ctx context.Context) *sseManager { 69 | sseM := &sseManager{ 70 | mu: sync.Mutex{}, 71 | sseSessions: make(map[string]*sseSession), 72 | } 73 | go sseM.cleanupRoutine(ctx) 74 | return sseM 75 | } 76 | 77 | func (m *sseManager) add(id string, session *sseSession) { 78 | m.mu.Lock() 79 | defer m.mu.Unlock() 80 | m.sseSessions[id] = session 81 | session.lastActive = time.Now() 82 | } 83 | 84 | func (m *sseManager) remove(id string) { 85 | m.mu.Lock() 86 | delete(m.sseSessions, id) 87 | m.mu.Unlock() 88 | } 89 | 90 | func (m *sseManager) cleanupRoutine(ctx context.Context) { 91 | timeout := 10 * time.Minute 92 | ticker := time.NewTicker(timeout) 93 | defer ticker.Stop() 94 | 95 | for { 96 | select { 97 | case <-ctx.Done(): 98 | return 99 | case <-ticker.C: 100 | func() { 101 | m.mu.Lock() 102 | defer m.mu.Unlock() 103 | now := time.Now() 104 | for id, sess := range m.sseSessions { 105 | if now.Sub(sess.lastActive) > timeout { 106 | delete(m.sseSessions, id) 107 | } 108 | } 109 | }() 110 | } 111 | } 112 | } 113 | 114 | type stdioSession struct { 115 | protocol string 116 | server *Server 117 | reader *bufio.Reader 118 | writer io.Writer 119 | } 120 | 121 | func 
NewStdioSession(s *Server, stdin io.Reader, stdout io.Writer) *stdioSession { 122 | stdioSession := &stdioSession{ 123 | server: s, 124 | reader: bufio.NewReader(stdin), 125 | writer: stdout, 126 | } 127 | return stdioSession 128 | } 129 | 130 | func (s *stdioSession) Start(ctx context.Context) error { 131 | return s.readInputStream(ctx) 132 | } 133 | 134 | // readInputStream reads requests/notifications from MCP clients through stdin 135 | func (s *stdioSession) readInputStream(ctx context.Context) error { 136 | for { 137 | if err := ctx.Err(); err != nil { 138 | return err 139 | } 140 | line, err := s.readLine(ctx) 141 | if err != nil { 142 | if err == io.EOF { 143 | return nil 144 | } 145 | return err 146 | } 147 | v, res, err := processMcpMessage(ctx, []byte(line), s.server, s.protocol, "", nil) 148 | if err != nil { 149 | // errors during the processing of message will generate a valid MCP Error response. 150 | // server can continue to run. 151 | s.server.logger.ErrorContext(ctx, err.Error()) 152 | } 153 | if v != "" { 154 | s.protocol = v 155 | } 156 | // no responses for notifications 157 | if res != nil { 158 | if err = s.write(ctx, res); err != nil { 159 | return err 160 | } 161 | } 162 | } 163 | } 164 | 165 | // readLine process each line within the input stream. 166 | func (s *stdioSession) readLine(ctx context.Context) (string, error) { 167 | readChan := make(chan string, 1) 168 | errChan := make(chan error, 1) 169 | done := make(chan struct{}) 170 | defer close(done) 171 | defer close(readChan) 172 | defer close(errChan) 173 | 174 | go func() { 175 | select { 176 | case <-done: 177 | return 178 | default: 179 | line, err := s.reader.ReadString('\n') 180 | if err != nil { 181 | select { 182 | case errChan <- err: 183 | case <-done: 184 | } 185 | return 186 | } 187 | select { 188 | case readChan <- line: 189 | case <-done: 190 | } 191 | return 192 | } 193 | }() 194 | 195 | select { 196 | // if context is cancelled, return an empty string 197 | case <-ctx.Done(): 198 | return "", ctx.Err() 199 | // return error if error is found 200 | case err := <-errChan: 201 | return "", err 202 | // return line if successful 203 | case line := <-readChan: 204 | return line, nil 205 | } 206 | } 207 | 208 | // write writes to stdout with response to client 209 | func (s *stdioSession) write(ctx context.Context, response any) error { 210 | res, _ := json.Marshal(response) 211 | 212 | _, err := fmt.Fprintf(s.writer, "%s\n", res) 213 | return err 214 | } 215 | 216 | // mcpRouter creates a router that represents the routes under /mcp 217 | func mcpRouter(s *Server) (chi.Router, error) { 218 | r := chi.NewRouter() 219 | 220 | r.Use(middleware.AllowContentType("application/json", "application/json-rpc", "application/jsonrequest")) 221 | r.Use(middleware.StripSlashes) 222 | r.Use(render.SetContentType(render.ContentTypeJSON)) 223 | 224 | r.Get("/sse", func(w http.ResponseWriter, r *http.Request) { sseHandler(s, w, r) }) 225 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { methodNotAllowed(s, w, r) }) 226 | r.Post("/", func(w http.ResponseWriter, r *http.Request) { httpHandler(s, w, r) }) 227 | r.Delete("/", func(w http.ResponseWriter, r *http.Request) {}) 228 | 229 | r.Route("/{toolsetName}", func(r chi.Router) { 230 | r.Get("/sse", func(w http.ResponseWriter, r *http.Request) { sseHandler(s, w, r) }) 231 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { methodNotAllowed(s, w, r) }) 232 | r.Post("/", func(w http.ResponseWriter, r *http.Request) { httpHandler(s, w, r) }) 233 | 
r.Delete("/", func(w http.ResponseWriter, r *http.Request) {}) 234 | }) 235 | 236 | return r, nil 237 | } 238 | 239 | // sseHandler handles sse initialization and message. 240 | func sseHandler(s *Server, w http.ResponseWriter, r *http.Request) { 241 | ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/mcp/sse") 242 | r = r.WithContext(ctx) 243 | 244 | sessionId := uuid.New().String() 245 | toolsetName := chi.URLParam(r, "toolsetName") 246 | s.logger.DebugContext(ctx, fmt.Sprintf("toolset name: %s", toolsetName)) 247 | span.SetAttributes(attribute.String("session_id", sessionId)) 248 | span.SetAttributes(attribute.String("toolset_name", toolsetName)) 249 | 250 | w.Header().Set("Content-Type", "text/event-stream") 251 | w.Header().Set("Cache-Control", "no-cache") 252 | w.Header().Set("Connection", "keep-alive") 253 | w.Header().Set("Access-Control-Allow-Origin", "*") 254 | 255 | var err error 256 | defer func() { 257 | if err != nil { 258 | span.SetStatus(codes.Error, err.Error()) 259 | } 260 | span.End() 261 | status := "success" 262 | if err != nil { 263 | status = "error" 264 | } 265 | s.instrumentation.McpSse.Add( 266 | r.Context(), 267 | 1, 268 | metric.WithAttributes(attribute.String("toolbox.toolset.name", toolsetName)), 269 | metric.WithAttributes(attribute.String("toolbox.sse.sessionId", sessionId)), 270 | metric.WithAttributes(attribute.String("toolbox.operation.status", status)), 271 | ) 272 | }() 273 | 274 | flusher, ok := w.(http.Flusher) 275 | if !ok { 276 | err = fmt.Errorf("unable to retrieve flusher for sse") 277 | s.logger.DebugContext(ctx, err.Error()) 278 | _ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError)) 279 | } 280 | session := &sseSession{ 281 | writer: w, 282 | flusher: flusher, 283 | done: make(chan struct{}), 284 | eventQueue: make(chan string, 100), 285 | } 286 | s.sseManager.add(sessionId, session) 287 | defer s.sseManager.remove(sessionId) 288 | 289 | // https scheme formatting if (forwarded) request is a TLS request 290 | proto := r.Header.Get("X-Forwarded-Proto") 291 | if proto == "" { 292 | if r.TLS == nil { 293 | proto = "http" 294 | } else { 295 | proto = "https" 296 | } 297 | } 298 | 299 | // send initial endpoint event 300 | toolsetURL := "" 301 | if toolsetName != "" { 302 | toolsetURL = fmt.Sprintf("/%s", toolsetName) 303 | } 304 | messageEndpoint := fmt.Sprintf("%s://%s/mcp%s?sessionId=%s", proto, r.Host, toolsetURL, sessionId) 305 | s.logger.DebugContext(ctx, fmt.Sprintf("sending endpoint event: %s", messageEndpoint)) 306 | fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", messageEndpoint) 307 | flusher.Flush() 308 | 309 | clientClose := r.Context().Done() 310 | for { 311 | select { 312 | // Ensure that only a single responses are written at once 313 | case event := <-session.eventQueue: 314 | fmt.Fprint(w, event) 315 | s.logger.DebugContext(ctx, fmt.Sprintf("sending event: %s", event)) 316 | flusher.Flush() 317 | // channel for client disconnection 318 | case <-clientClose: 319 | close(session.done) 320 | s.logger.DebugContext(ctx, "client disconnected") 321 | return 322 | } 323 | } 324 | } 325 | 326 | // methodNotAllowed handles all mcp messages. 
327 | func methodNotAllowed(s *Server, w http.ResponseWriter, r *http.Request) { 328 | err := fmt.Errorf("toolbox does not support streaming in streamable HTTP transport") 329 | s.logger.DebugContext(r.Context(), err.Error()) 330 | _ = render.Render(w, r, newErrResponse(err, http.StatusMethodNotAllowed)) 331 | } 332 | 333 | // httpHandler handles all mcp messages. 334 | func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) { 335 | w.Header().Set("Content-Type", "application/json") 336 | 337 | ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/mcp") 338 | r = r.WithContext(ctx) 339 | ctx = util.WithLogger(r.Context(), s.logger) 340 | 341 | var sessionId, protocolVersion string 342 | var session *sseSession 343 | 344 | // check if client connects via sse 345 | // v2024-11-05 supports http with sse 346 | paramSessionId := r.URL.Query().Get("sessionId") 347 | if paramSessionId != "" { 348 | sessionId = paramSessionId 349 | protocolVersion = v20241105.PROTOCOL_VERSION 350 | var ok bool 351 | session, ok = s.sseManager.get(sessionId) 352 | if !ok { 353 | s.logger.DebugContext(ctx, "sse session not available") 354 | } 355 | } 356 | 357 | // check if client have `Mcp-Session-Id` header 358 | // `Mcp-Session-Id` is only set for v2025-03-26 in Toolbox 359 | headerSessionId := r.Header.Get("Mcp-Session-Id") 360 | if headerSessionId != "" { 361 | protocolVersion = v20250326.PROTOCOL_VERSION 362 | } 363 | 364 | // check if client have `MCP-Protocol-Version` header 365 | // Only supported for v2025-06-18+. 366 | headerProtocolVersion := r.Header.Get("MCP-Protocol-Version") 367 | if headerProtocolVersion != "" { 368 | if !mcp.VerifyProtocolVersion(headerProtocolVersion) { 369 | err := fmt.Errorf("invalid protocol version: %s", headerProtocolVersion) 370 | _ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest)) 371 | return 372 | } 373 | protocolVersion = headerProtocolVersion 374 | } 375 | 376 | toolsetName := chi.URLParam(r, "toolsetName") 377 | s.logger.DebugContext(ctx, fmt.Sprintf("toolset name: %s", toolsetName)) 378 | span.SetAttributes(attribute.String("toolset_name", toolsetName)) 379 | 380 | var err error 381 | defer func() { 382 | if err != nil { 383 | span.SetStatus(codes.Error, err.Error()) 384 | } 385 | span.End() 386 | 387 | status := "success" 388 | if err != nil { 389 | status = "error" 390 | } 391 | s.instrumentation.McpPost.Add( 392 | r.Context(), 393 | 1, 394 | metric.WithAttributes(attribute.String("toolbox.sse.sessionId", sessionId)), 395 | metric.WithAttributes(attribute.String("toolbox.operation.status", status)), 396 | ) 397 | }() 398 | 399 | // Read and returns a body from io.Reader 400 | body, err := io.ReadAll(r.Body) 401 | if err != nil { 402 | // Generate a new uuid if unable to decode 403 | id := uuid.New().String() 404 | s.logger.DebugContext(ctx, err.Error()) 405 | render.JSON(w, r, jsonrpc.NewError(id, jsonrpc.PARSE_ERROR, err.Error(), nil)) 406 | return 407 | } 408 | 409 | v, res, err := processMcpMessage(ctx, body, s, protocolVersion, toolsetName, r.Header) 410 | if err != nil { 411 | s.logger.DebugContext(ctx, fmt.Errorf("error processing message: %w", err).Error()) 412 | } 413 | 414 | // notifications will return empty string 415 | if res == nil { 416 | // Notifications do not expect a response 417 | // Toolbox doesn't do anything with notifications yet 418 | w.WriteHeader(http.StatusAccepted) 419 | return 420 | } 421 | 422 | // for v20250326, add the `Mcp-Session-Id` header 423 | if v == v20250326.PROTOCOL_VERSION { 424 
| sessionId = uuid.New().String() 425 | w.Header().Set("Mcp-Session-Id", sessionId) 426 | } 427 | 428 | if session != nil { 429 | // queue sse event 430 | eventData, _ := json.Marshal(res) 431 | select { 432 | case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData): 433 | s.logger.DebugContext(ctx, "event queue successful") 434 | case <-session.done: 435 | s.logger.DebugContext(ctx, "session is close") 436 | default: 437 | s.logger.DebugContext(ctx, "unable to add to event queue") 438 | } 439 | } 440 | if rpcResponse, ok := res.(jsonrpc.JSONRPCError); ok { 441 | code := rpcResponse.Error.Code 442 | switch code { 443 | case jsonrpc.INTERNAL_ERROR: 444 | w.WriteHeader(http.StatusInternalServerError) 445 | case jsonrpc.INVALID_REQUEST: 446 | errStr := err.Error() 447 | if errors.Is(err, tools.ErrUnauthorized) { 448 | w.WriteHeader(http.StatusUnauthorized) 449 | } else if strings.Contains(errStr, "Error 401") { 450 | w.WriteHeader(http.StatusUnauthorized) 451 | } else if strings.Contains(errStr, "Error 403") { 452 | w.WriteHeader(http.StatusForbidden) 453 | } 454 | } 455 | } 456 | 457 | // send HTTP response 458 | render.JSON(w, r, res) 459 | } 460 | 461 | // processMcpMessage process the messages received from clients 462 | func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVersion string, toolsetName string, header http.Header) (string, any, error) { 463 | logger, err := util.LoggerFromContext(ctx) 464 | if err != nil { 465 | return "", jsonrpc.NewError("", jsonrpc.INTERNAL_ERROR, err.Error(), nil), err 466 | } 467 | 468 | // Generic baseMessage could either be a JSONRPCNotification or JSONRPCRequest 469 | var baseMessage jsonrpc.BaseMessage 470 | if err = util.DecodeJSON(bytes.NewBuffer(body), &baseMessage); err != nil { 471 | // Generate a new uuid if unable to decode 472 | id := uuid.New().String() 473 | 474 | // check if user is sending a batch request 475 | var a []any 476 | unmarshalErr := json.Unmarshal(body, &a) 477 | if unmarshalErr == nil { 478 | err = fmt.Errorf("not supporting batch requests") 479 | return "", jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err 480 | } 481 | 482 | return "", jsonrpc.NewError(id, jsonrpc.PARSE_ERROR, err.Error(), nil), err 483 | } 484 | 485 | // Check if method is present 486 | if baseMessage.Method == "" { 487 | err = fmt.Errorf("method not found") 488 | return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err 489 | } 490 | logger.DebugContext(ctx, fmt.Sprintf("method is: %s", baseMessage.Method)) 491 | 492 | // Check for JSON-RPC 2.0 493 | if baseMessage.Jsonrpc != jsonrpc.JSONRPC_VERSION { 494 | err = fmt.Errorf("invalid json-rpc version") 495 | return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err 496 | } 497 | 498 | // Check if message is a notification 499 | if baseMessage.Id == nil { 500 | err := mcp.NotificationHandler(ctx, body) 501 | return "", nil, err 502 | } 503 | 504 | switch baseMessage.Method { 505 | case mcputil.INITIALIZE: 506 | res, v, err := mcp.InitializeResponse(ctx, baseMessage.Id, body, s.version) 507 | if err != nil { 508 | return "", res, err 509 | } 510 | return v, res, err 511 | default: 512 | toolset, ok := s.ResourceMgr.GetToolset(toolsetName) 513 | if !ok { 514 | err = fmt.Errorf("toolset does not exist") 515 | return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err 516 | } 517 | res, err := mcp.ProcessMethod(ctx, protocolVersion, 
baseMessage.Id, baseMessage.Method, toolset, s.ResourceMgr.GetToolsMap(), s.ResourceMgr.GetAuthServiceMap(), body, header) 518 | return "", res, err 519 | } 520 | } 521 | ``` -------------------------------------------------------------------------------- /internal/tools/firestore/firestorequerycollection/firestorequerycollection.go: -------------------------------------------------------------------------------- ```go 1 | // Copyright 2025 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package firestorequerycollection 16 | 17 | import ( 18 | "context" 19 | "encoding/json" 20 | "fmt" 21 | "strings" 22 | 23 | firestoreapi "cloud.google.com/go/firestore" 24 | yaml "github.com/goccy/go-yaml" 25 | "github.com/googleapis/genai-toolbox/internal/sources" 26 | firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore" 27 | "github.com/googleapis/genai-toolbox/internal/tools" 28 | "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" 29 | ) 30 | 31 | // Constants for tool configuration 32 | const ( 33 | kind = "firestore-query-collection" 34 | defaultLimit = 100 35 | defaultAnalyze = false 36 | maxFilterLength = 100 // Maximum filters to prevent abuse 37 | ) 38 | 39 | // Parameter keys 40 | const ( 41 | collectionPathKey = "collectionPath" 42 | filtersKey = "filters" 43 | orderByKey = "orderBy" 44 | limitKey = "limit" 45 | analyzeQueryKey = "analyzeQuery" 46 | ) 47 | 48 | // Firestore operators 49 | var validOperators = map[string]bool{ 50 | "<": true, 51 | "<=": true, 52 | ">": true, 53 | ">=": true, 54 | "==": true, 55 | "!=": true, 56 | "array-contains": true, 57 | "array-contains-any": true, 58 | "in": true, 59 | "not-in": true, 60 | } 61 | 62 | // Error messages 63 | const ( 64 | errMissingCollectionPath = "invalid or missing '%s' parameter" 65 | errInvalidFilters = "invalid '%s' parameter; expected an array" 66 | errFilterNotString = "filter at index %d is not a string" 67 | errFilterParseFailed = "failed to parse filter at index %d: %w" 68 | errInvalidOperator = "unsupported operator: %s. 
Valid operators are: %v" 69 | errMissingFilterValue = "no value specified for filter on field '%s'" 70 | errOrderByParseFailed = "failed to parse orderBy: %w" 71 | errQueryExecutionFailed = "failed to execute query: %w" 72 | errTooManyFilters = "too many filters provided: %d (maximum: %d)" 73 | ) 74 | 75 | func init() { 76 | if !tools.Register(kind, newConfig) { 77 | panic(fmt.Sprintf("tool kind %q already registered", kind)) 78 | } 79 | } 80 | 81 | func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { 82 | actual := Config{Name: name} 83 | if err := decoder.DecodeContext(ctx, &actual); err != nil { 84 | return nil, err 85 | } 86 | return actual, nil 87 | } 88 | 89 | // compatibleSource defines the interface for sources that can provide a Firestore client 90 | type compatibleSource interface { 91 | FirestoreClient() *firestoreapi.Client 92 | } 93 | 94 | // validate compatible sources are still compatible 95 | var _ compatibleSource = &firestoreds.Source{} 96 | 97 | var compatibleSources = [...]string{firestoreds.SourceKind} 98 | 99 | // Config represents the configuration for the Firestore query collection tool 100 | type Config struct { 101 | Name string `yaml:"name" validate:"required"` 102 | Kind string `yaml:"kind" validate:"required"` 103 | Source string `yaml:"source" validate:"required"` 104 | Description string `yaml:"description" validate:"required"` 105 | AuthRequired []string `yaml:"authRequired"` 106 | } 107 | 108 | // validate interface 109 | var _ tools.ToolConfig = Config{} 110 | 111 | // ToolConfigKind returns the kind of tool configuration 112 | func (cfg Config) ToolConfigKind() string { 113 | return kind 114 | } 115 | 116 | // Initialize creates a new Tool instance from the configuration 117 | func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { 118 | // verify source exists 119 | rawS, ok := srcs[cfg.Source] 120 | if !ok { 121 | return nil, fmt.Errorf("no source named %q configured", cfg.Source) 122 | } 123 | 124 | // verify the source is compatible 125 | s, ok := rawS.(compatibleSource) 126 | if !ok { 127 | return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) 128 | } 129 | 130 | // Create parameters 131 | parameters := createParameters() 132 | 133 | mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, parameters) 134 | 135 | // finish tool setup 136 | t := Tool{ 137 | Name: cfg.Name, 138 | Kind: kind, 139 | Parameters: parameters, 140 | AuthRequired: cfg.AuthRequired, 141 | Client: s.FirestoreClient(), 142 | manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired}, 143 | mcpManifest: mcpManifest, 144 | } 145 | return t, nil 146 | } 147 | 148 | // createParameters creates the parameter definitions for the tool 149 | func createParameters() tools.Parameters { 150 | collectionPathParameter := tools.NewStringParameter( 151 | collectionPathKey, 152 | "The relative path to the Firestore collection to query (e.g., 'users' or 'users/userId/posts'). Note: This is a relative path, NOT an absolute path like 'projects/{project_id}/databases/{database_id}/documents/...'", 153 | ) 154 | 155 | filtersDescription := `Array of filter objects to apply to the query. 
Each filter is a JSON string with: 156 | - field: The field name to filter on 157 | - op: The operator to use ("<", "<=", ">", ">=", "==", "!=", "array-contains", "array-contains-any", "in", "not-in") 158 | - value: The value to compare against (can be string, number, boolean, or array) 159 | Example: {"field": "age", "op": ">", "value": 18}` 160 | 161 | filtersParameter := tools.NewArrayParameter( 162 | filtersKey, 163 | filtersDescription, 164 | tools.NewStringParameter("item", "JSON string representation of a filter object"), 165 | ) 166 | 167 | orderByParameter := tools.NewStringParameter( 168 | orderByKey, 169 | "JSON string specifying the field and direction to order by (e.g., {\"field\": \"name\", \"direction\": \"ASCENDING\"}). Leave empty if not specified", 170 | ) 171 | 172 | limitParameter := tools.NewIntParameterWithDefault( 173 | limitKey, 174 | defaultLimit, 175 | "The maximum number of documents to return", 176 | ) 177 | 178 | analyzeQueryParameter := tools.NewBooleanParameterWithDefault( 179 | analyzeQueryKey, 180 | defaultAnalyze, 181 | "If true, returns query explain metrics including execution statistics", 182 | ) 183 | 184 | return tools.Parameters{ 185 | collectionPathParameter, 186 | filtersParameter, 187 | orderByParameter, 188 | limitParameter, 189 | analyzeQueryParameter, 190 | } 191 | } 192 | 193 | // validate interface 194 | var _ tools.Tool = Tool{} 195 | 196 | // Tool represents the Firestore query collection tool 197 | type Tool struct { 198 | Name string `yaml:"name"` 199 | Kind string `yaml:"kind"` 200 | AuthRequired []string `yaml:"authRequired"` 201 | Parameters tools.Parameters `yaml:"parameters"` 202 | 203 | Client *firestoreapi.Client 204 | manifest tools.Manifest 205 | mcpManifest tools.McpManifest 206 | } 207 | 208 | // FilterConfig represents a filter for the query 209 | type FilterConfig struct { 210 | Field string `json:"field"` 211 | Op string `json:"op"` 212 | Value interface{} `json:"value"` 213 | } 214 | 215 | // Validate checks if the filter configuration is valid 216 | func (f *FilterConfig) Validate() error { 217 | if f.Field == "" { 218 | return fmt.Errorf("filter field cannot be empty") 219 | } 220 | 221 | if !validOperators[f.Op] { 222 | ops := make([]string, 0, len(validOperators)) 223 | for op := range validOperators { 224 | ops = append(ops, op) 225 | } 226 | return fmt.Errorf(errInvalidOperator, f.Op, ops) 227 | } 228 | 229 | if f.Value == nil { 230 | return fmt.Errorf(errMissingFilterValue, f.Field) 231 | } 232 | 233 | return nil 234 | } 235 | 236 | // OrderByConfig represents ordering configuration 237 | type OrderByConfig struct { 238 | Field string `json:"field"` 239 | Direction string `json:"direction"` 240 | } 241 | 242 | // GetDirection returns the Firestore direction constant 243 | func (o *OrderByConfig) GetDirection() firestoreapi.Direction { 244 | if strings.EqualFold(o.Direction, "DESCENDING") { 245 | return firestoreapi.Desc 246 | } 247 | return firestoreapi.Asc 248 | } 249 | 250 | // QueryResult represents a document result from the query 251 | type QueryResult struct { 252 | ID string `json:"id"` 253 | Path string `json:"path"` 254 | Data map[string]any `json:"data"` 255 | CreateTime interface{} `json:"createTime,omitempty"` 256 | UpdateTime interface{} `json:"updateTime,omitempty"` 257 | ReadTime interface{} `json:"readTime,omitempty"` 258 | } 259 | 260 | // QueryResponse represents the full response including optional metrics 261 | type QueryResponse struct { 262 | Documents []QueryResult `json:"documents"` 263 | 
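// ExplainMetrics holds the plan summary and execution statistics; it is only
// populated when the analyzeQuery parameter is true.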
ExplainMetrics map[string]any `json:"explainMetrics,omitempty"` 264 | } 265 | 266 | // Invoke executes the Firestore query based on the provided parameters 267 | func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) { 268 | // Parse parameters 269 | queryParams, err := t.parseQueryParameters(params) 270 | if err != nil { 271 | return nil, err 272 | } 273 | 274 | // Build the query 275 | query, err := t.buildQuery(queryParams) 276 | if err != nil { 277 | return nil, err 278 | } 279 | 280 | // Execute the query and return results 281 | return t.executeQuery(ctx, query, queryParams.AnalyzeQuery) 282 | } 283 | 284 | // queryParameters holds all parsed query parameters 285 | type queryParameters struct { 286 | CollectionPath string 287 | Filters []FilterConfig 288 | OrderBy *OrderByConfig 289 | Limit int 290 | AnalyzeQuery bool 291 | } 292 | 293 | // parseQueryParameters extracts and validates parameters from the input 294 | func (t Tool) parseQueryParameters(params tools.ParamValues) (*queryParameters, error) { 295 | mapParams := params.AsMap() 296 | 297 | // Get collection path 298 | collectionPath, ok := mapParams[collectionPathKey].(string) 299 | if !ok || collectionPath == "" { 300 | return nil, fmt.Errorf(errMissingCollectionPath, collectionPathKey) 301 | } 302 | 303 | // Validate collection path 304 | if err := util.ValidateCollectionPath(collectionPath); err != nil { 305 | return nil, fmt.Errorf("invalid collection path: %w", err) 306 | } 307 | 308 | result := &queryParameters{ 309 | CollectionPath: collectionPath, 310 | Limit: defaultLimit, 311 | AnalyzeQuery: defaultAnalyze, 312 | } 313 | 314 | // Parse filters 315 | if filtersRaw, ok := mapParams[filtersKey]; ok && filtersRaw != nil { 316 | filters, err := t.parseFilters(filtersRaw) 317 | if err != nil { 318 | return nil, err 319 | } 320 | result.Filters = filters 321 | } 322 | 323 | // Parse orderBy 324 | if orderByRaw, ok := mapParams[orderByKey]; ok && orderByRaw != nil { 325 | orderBy, err := t.parseOrderBy(orderByRaw) 326 | if err != nil { 327 | return nil, err 328 | } 329 | result.OrderBy = orderBy 330 | } 331 | 332 | // Parse limit 333 | if limit, ok := mapParams[limitKey].(int); ok { 334 | result.Limit = limit 335 | } 336 | 337 | // Parse analyze 338 | if analyze, ok := mapParams[analyzeQueryKey].(bool); ok { 339 | result.AnalyzeQuery = analyze 340 | } 341 | 342 | return result, nil 343 | } 344 | 345 | // parseFilters parses and validates filter configurations 346 | func (t Tool) parseFilters(filtersRaw interface{}) ([]FilterConfig, error) { 347 | filters, ok := filtersRaw.([]any) 348 | if !ok { 349 | return nil, fmt.Errorf(errInvalidFilters, filtersKey) 350 | } 351 | 352 | if len(filters) > maxFilterLength { 353 | return nil, fmt.Errorf(errTooManyFilters, len(filters), maxFilterLength) 354 | } 355 | 356 | result := make([]FilterConfig, 0, len(filters)) 357 | for i, filterRaw := range filters { 358 | filterJSON, ok := filterRaw.(string) 359 | if !ok { 360 | return nil, fmt.Errorf(errFilterNotString, i) 361 | } 362 | 363 | var filter FilterConfig 364 | if err := json.Unmarshal([]byte(filterJSON), &filter); err != nil { 365 | return nil, fmt.Errorf(errFilterParseFailed, i, err) 366 | } 367 | 368 | if err := filter.Validate(); err != nil { 369 | return nil, fmt.Errorf("filter at index %d is invalid: %w", i, err) 370 | } 371 | 372 | result = append(result, filter) 373 | } 374 | 375 | return result, nil 376 | } 377 | 378 | // parseOrderBy parses the orderBy configuration 
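// Empty input or a missing "field" yields a nil config (no ordering). A valid
// input looks like: {"field": "name", "direction": "DESCENDING"}.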
379 | func (t Tool) parseOrderBy(orderByRaw interface{}) (*OrderByConfig, error) { 380 | orderByJSON, ok := orderByRaw.(string) 381 | if !ok || orderByJSON == "" { 382 | return nil, nil 383 | } 384 | 385 | var orderBy OrderByConfig 386 | if err := json.Unmarshal([]byte(orderByJSON), &orderBy); err != nil { 387 | return nil, fmt.Errorf(errOrderByParseFailed, err) 388 | } 389 | 390 | if orderBy.Field == "" { 391 | return nil, nil 392 | } 393 | 394 | return &orderBy, nil 395 | } 396 | 397 | // buildQuery constructs the Firestore query from parameters 398 | func (t Tool) buildQuery(params *queryParameters) (*firestoreapi.Query, error) { 399 | collection := t.Client.Collection(params.CollectionPath) 400 | query := collection.Query 401 | 402 | // Apply filters 403 | if len(params.Filters) > 0 { 404 | filterConditions := make([]firestoreapi.EntityFilter, 0, len(params.Filters)) 405 | for _, filter := range params.Filters { 406 | filterConditions = append(filterConditions, firestoreapi.PropertyFilter{ 407 | Path: filter.Field, 408 | Operator: filter.Op, 409 | Value: filter.Value, 410 | }) 411 | } 412 | 413 | query = query.WhereEntity(firestoreapi.AndFilter{ 414 | Filters: filterConditions, 415 | }) 416 | } 417 | 418 | // Apply ordering 419 | if params.OrderBy != nil { 420 | query = query.OrderBy(params.OrderBy.Field, params.OrderBy.GetDirection()) 421 | } 422 | 423 | // Apply limit 424 | query = query.Limit(params.Limit) 425 | 426 | // Apply analyze options 427 | if params.AnalyzeQuery { 428 | query = query.WithRunOptions(firestoreapi.ExplainOptions{ 429 | Analyze: true, 430 | }) 431 | } 432 | 433 | return &query, nil 434 | } 435 | 436 | // executeQuery runs the query and formats the results 437 | func (t Tool) executeQuery(ctx context.Context, query *firestoreapi.Query, analyzeQuery bool) (any, error) { 438 | docIterator := query.Documents(ctx) 439 | docs, err := docIterator.GetAll() 440 | if err != nil { 441 | return nil, fmt.Errorf(errQueryExecutionFailed, err) 442 | } 443 | 444 | // Convert results to structured format 445 | results := make([]QueryResult, len(docs)) 446 | for i, doc := range docs { 447 | results[i] = QueryResult{ 448 | ID: doc.Ref.ID, 449 | Path: doc.Ref.Path, 450 | Data: doc.Data(), 451 | CreateTime: doc.CreateTime, 452 | UpdateTime: doc.UpdateTime, 453 | ReadTime: doc.ReadTime, 454 | } 455 | } 456 | 457 | // Return with explain metrics if requested 458 | if analyzeQuery { 459 | explainMetrics, err := t.getExplainMetrics(docIterator) 460 | if err == nil && explainMetrics != nil { 461 | response := QueryResponse{ 462 | Documents: results, 463 | ExplainMetrics: explainMetrics, 464 | } 465 | return response, nil 466 | } 467 | } 468 | 469 | // Return just the documents 470 | resultsAny := make([]any, len(results)) 471 | for i, r := range results { 472 | resultsAny[i] = r 473 | } 474 | return resultsAny, nil 475 | } 476 | 477 | // getExplainMetrics extracts explain metrics from the query iterator 478 | func (t Tool) getExplainMetrics(docIterator *firestoreapi.DocumentIterator) (map[string]any, error) { 479 | explainMetrics, err := docIterator.ExplainMetrics() 480 | if err != nil || explainMetrics == nil { 481 | return nil, err 482 | } 483 | 484 | metricsData := make(map[string]any) 485 | 486 | // Add plan summary if available 487 | if explainMetrics.PlanSummary != nil { 488 | planSummary := make(map[string]any) 489 | planSummary["indexesUsed"] = explainMetrics.PlanSummary.IndexesUsed 490 | metricsData["planSummary"] = planSummary 491 | } 492 | 493 | // Add execution stats if 
available 494 | if explainMetrics.ExecutionStats != nil { 495 | executionStats := make(map[string]any) 496 | executionStats["resultsReturned"] = explainMetrics.ExecutionStats.ResultsReturned 497 | executionStats["readOperations"] = explainMetrics.ExecutionStats.ReadOperations 498 | 499 | if explainMetrics.ExecutionStats.ExecutionDuration != nil { 500 | executionStats["executionDuration"] = explainMetrics.ExecutionStats.ExecutionDuration.String() 501 | } 502 | 503 | if explainMetrics.ExecutionStats.DebugStats != nil { 504 | executionStats["debugStats"] = *explainMetrics.ExecutionStats.DebugStats 505 | } 506 | 507 | metricsData["executionStats"] = executionStats 508 | } 509 | 510 | return metricsData, nil 511 | } 512 | 513 | // ParseParams parses and validates input parameters 514 | func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) { 515 | return tools.ParseParams(t.Parameters, data, claims) 516 | } 517 | 518 | // Manifest returns the tool manifest 519 | func (t Tool) Manifest() tools.Manifest { 520 | return t.manifest 521 | } 522 | 523 | // McpManifest returns the MCP manifest 524 | func (t Tool) McpManifest() tools.McpManifest { 525 | return t.mcpManifest 526 | } 527 | 528 | // Authorized checks if the tool is authorized based on verified auth services 529 | func (t Tool) Authorized(verifiedAuthServices []string) bool { 530 | return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) 531 | } 532 | 533 | func (t Tool) RequiresClientAuthorization() bool { 534 | return false 535 | } 536 | ``` -------------------------------------------------------------------------------- /internal/prebuiltconfigs/tools/cloud-sql-mysql-observability.yaml: -------------------------------------------------------------------------------- ```yaml 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | sources: 15 | cloud-monitoring-source: 16 | kind: cloud-monitoring 17 | tools: 18 | get_system_metrics: 19 | kind: cloud-monitoring-query-prometheus 20 | source: cloud-monitoring-source 21 | description: | 22 | Fetches system level cloudmonitoring data (timeseries metrics) for a MySQL instance using a PromQL query. Take projectId and instanceId from the user for which the metrics timeseries data needs to be fetched. 23 | To use this tool, you must provide the Google Cloud `projectId` and a PromQL `query`. 24 | 25 | Generate PromQL `query` for MySQL system metrics. Use the provided metrics and rules to construct queries, Get the labels like `instance_id` from user intent. 26 | 27 | Defaults: 28 | 1. Interval: Use a default interval of `5m` for `_over_time` aggregation functions unless a different window is specified by the user. 29 | 30 | PromQL Query Examples: 31 | 1. Basic Time Series: `avg_over_time({"__name__"="cloudsql.googleapis.com/database/cpu/utilization","monitored_resource"="cloudsql_database","project_id"="my-projectId","database_id"="my-projectId:my-instanceId"}[5m])` 32 | 2. 
Top K: `topk(30, avg_over_time({"__name__"="cloudsql.googleapis.com/database/cpu/utilization","monitored_resource"="cloudsql_database","project_id"="my-projectId","database_id"="my-projectId:my-instanceId"}[5m]))` 33 | 3. Mean: `avg(avg_over_time({"__name__"="cloudsql.googleapis.com/database/cpu/utilization","monitored_resource"="cloudsql_database","project_id"="my-projectId","database_id"="my-projectId:my-instanceId"}[5m]))` 34 | 4. Minimum: `min(min_over_time({"__name__"="cloudsql.googleapis.com/database/cpu/utilization","monitored_resource"="cloudsql_database","project_id"="my-projectId","database_id"="my-projectId:my-instanceId"}[5m]))` 35 | 5. Maximum: `max(max_over_time({"__name__"="cloudsql.googleapis.com/database/cpu/utilization","monitored_resource"="cloudsql_database","project_id"="my-projectId","database_id"="my-projectId:my-instanceId"}[5m]))` 36 | 6. Sum: `sum(avg_over_time({"__name__"="cloudsql.googleapis.com/database/cpu/utilization","monitored_resource"="cloudsql_database","project_id"="my-projectId","database_id"="my-projectId:my-instanceId"}[5m]))` 37 | 7. Count streams: `count(avg_over_time({"__name__"="cloudsql.googleapis.com/database/cpu/utilization","monitored_resource"="cloudsql_database","project_id"="my-projectId","database_id"="my-projectId:my-instanceId"}[5m]))` 38 | 8. Percentile with groupby on database_id: `quantile by ("database_id")(0.99,avg_over_time({"__name__"="cloudsql.googleapis.com/database/cpu/utilization","monitored_resource"="cloudsql_database","project_id"="my-projectId","database_id"="my-projectId:my-instanceId"}[5m]))` 39 | 40 | Available Metrics List: metricname. description. monitored resource. labels. database_id is actually the instance id and the format is `project_id:instance_id`. 41 | 1. `cloudsql.googleapis.com/database/cpu/utilization`: Current CPU utilization as a percentage of reserved CPU. `cloudsql_database`. `database`, `project_id`, `database_id`. 42 | 2. `cloudsql.googleapis.com/database/network/connections`: Number of connections to the database instance. `cloudsql_database`. `database`, `project_id`, `database_id`. 43 | 3. `cloudsql.googleapis.com/database/network/received_bytes_count`: Delta count of bytes received through the network. `cloudsql_database`. `database`, `project_id`, `database_id`. 44 | 4. `cloudsql.googleapis.com/database/network/sent_bytes_count`: Delta count of bytes sent through the network. `cloudsql_database`. `destination`, `database`, `project_id`, `database_id`. 45 | 5. `cloudsql.googleapis.com/database/memory/components`: Memory usage for components like usage, cache, and free memory. `cloudsql_database`. `component`, `database`, `project_id`, `database_id`. 46 | 6. `cloudsql.googleapis.com/database/disk/bytes_used_by_data_type`: Data utilization in bytes. `cloudsql_database`. `data_type`, `database`, `project_id`, `database_id`. 47 | 7. `cloudsql.googleapis.com/database/disk/read_ops_count`: Delta count of data disk read IO operations. `cloudsql_database`. `database`, `project_id`, `database_id`. 48 | 8. `cloudsql.googleapis.com/database/disk/write_ops_count`: Delta count of data disk write IO operations. `cloudsql_database`. `database`, `project_id`, `database_id`. 49 | 9. `cloudsql.googleapis.com/database/mysql/queries`: Delta count of statements executed by the server. `cloudsql_database`. `database`, `project_id`, `database_id`. 50 | 10. `cloudsql.googleapis.com/database/mysql/questions`: Delta count of statements sent by the client. `cloudsql_database`. `database`, `project_id`, `database_id`. 
51 | 11. `cloudsql.googleapis.com/database/mysql/received_bytes_count`: Delta count of bytes received by MySQL process. `cloudsql_database`. `database`, `project_id`, `database_id`. 52 | 12. `cloudsql.googleapis.com/database/mysql/sent_bytes_count`: Delta count of bytes sent by MySQL process. `cloudsql_database`. `database`, `project_id`, `database_id`. 53 | 13. `cloudsql.googleapis.com/database/mysql/innodb_buffer_pool_pages_dirty`: Number of unflushed pages in the InnoDB buffer pool. `cloudsql_database`. `database`, `project_id`, `database_id`. 54 | 14. `cloudsql.googleapis.com/database/mysql/innodb_buffer_pool_pages_free`: Number of unused pages in the InnoDB buffer pool. `cloudsql_database`. `database`, `project_id`, `database_id`. 55 | 15. `cloudsql.googleapis.com/database/mysql/innodb_buffer_pool_pages_total`: Total number of pages in the InnoDB buffer pool. `cloudsql_database`. `database`, `project_id`, `database_id`. 56 | 16. `cloudsql.googleapis.com/database/mysql/innodb_data_fsyncs`: Delta count of InnoDB fsync() calls. `cloudsql_database`. `database`, `project_id`, `database_id`. 57 | 17. `cloudsql.googleapis.com/database/mysql/innodb_os_log_fsyncs`: Delta count of InnoDB fsync() calls to the log file. `cloudsql_database`. `database`, `project_id`, `database_id`. 58 | 18. `cloudsql.googleapis.com/database/mysql/innodb_pages_read`: Delta count of InnoDB pages read. `cloudsql_database`. `database`, `project_id`, `database_id`. 59 | 19. `cloudsql.googleapis.com/database/mysql/innodb_pages_written`: Delta count of InnoDB pages written. `cloudsql_database`. `database`, `project_id`, `database_id`. 60 | 20. `cloudsql.googleapis.com/database/mysql/open_tables`: The number of tables that are currently open. `cloudsql_database`. `database`, `project_id`, `database_id`. 61 | 21. `cloudsql.googleapis.com/database/mysql/opened_table_count`: The number of tables opened since the last sample. `cloudsql_database`. `database`, `project_id`, `database_id`. 62 | 22. `cloudsql.googleapis.com/database/mysql/open_table_definitions`: The number of table definitions currently cached. `cloudsql_database`. `database`, `project_id`, `database_id`. 63 | 23. `cloudsql.googleapis.com/database/mysql/opened_table_definitions_count`: The number of table definitions cached since the last sample. `cloudsql_database`. `database`, `project_id`, `database_id`. 64 | 24. `cloudsql.googleapis.com/database/mysql/innodb/dictionary_memory`: Memory allocated for the InnoDB dictionary cache. `cloudsql_database`. `database`, `project_id`, `database_id`. 65 | 66 | get_query_metrics: 67 | kind: cloud-monitoring-query-prometheus 68 | source: cloud-monitoring-source 69 | description: | 70 | Fetches query level cloudmonitoring data (timeseries metrics) for queries running in Mysql instance using a PromQL query. Take projectID and instanceID from the user for which the metrics timeseries data needs to be fetched. 71 | To use this tool, you must provide the Google Cloud `projectId` and a PromQL `query`. 72 | 73 | Generate PromQL `query` for Mysql query metrics. Use the provided metrics and rules to construct queries, Get the labels like `instance_id`, `query_hash` from user intent. If query_hash is provided then use the per_query metrics. Query hash and query id are same. 74 | 75 | Defaults: 76 | 1. Interval: Use a default interval of `5m` for `_over_time` aggregation functions unless a different window is specified by the user. 77 | 78 | PromQL Query Examples: 79 | 1. 
Basic Time Series: `avg_over_time({"__name__"="dbinsights.googleapis.com/aggregate/execution_time","monitored_resource"="cloudsql_instance_database","project_id"="my-projectId","resource_id"="my-projectId:my-instanceId"}[5m])` 80 | 2. Top K: `topk(30, avg_over_time({"__name__"="dbinsights.googleapis.com/aggregate/execution_time","monitored_resource"="cloudsql_instance_database","project_id"="my-projectId","resource_id"="my-projectId:my-instanceId"}[5m]))` 81 | 3. Mean: `avg(avg_over_time({"__name__"="dbinsights.googleapis.com/aggregate/execution_time","monitored_resource"="cloudsql_instance_database","project_id"="my-projectId","resource_id"="my-projectId:my-instanceId"}[5m]))` 82 | 4. Minimum: `min(min_over_time({"__name__"="dbinsights.googleapis.com/aggregate/execution_time","monitored_resource"="cloudsql_instance_database","project_id"="my-projectId","resource_id"="my-projectId:my-instanceId"}[5m]))` 83 | 5. Maximum: `max(max_over_time({"__name__"="dbinsights.googleapis.com/aggregate/execution_time","monitored_resource"="cloudsql_instance_database","project_id"="my-projectId","resource_id"="my-projectId:my-instanceId"}[5m]))` 84 | 6. Sum: `sum(avg_over_time({"__name__"="dbinsights.googleapis.com/aggregate/execution_time","monitored_resource"="cloudsql_instance_database","project_id"="my-projectId","resource_id"="my-projectId:my-instanceId"}[5m]))` 85 | 7. Count streams: `count(avg_over_time({"__name__"="dbinsights.googleapis.com/aggregate/execution_time","monitored_resource"="cloudsql_instance_database","project_id"="my-projectId","resource_id"="my-projectId:my-instanceId"}[5m]))` 86 | 8. Percentile with groupby on resource_id, database: `quantile by ("resource_id","database")(0.99,avg_over_time({"__name__"="dbinsights.googleapis.com/aggregate/execution_time","monitored_resource"="cloudsql_instance_database","project_id"="my-projectId","resource_id"="my-projectId:my-instanceId"}[5m]))` 87 | 88 | Available Metrics List: metricname. description. monitored resource. labels. resource_id label format is `project_id:instance_id` which is actually instance id only. aggregate is the aggregated values for all query stats, Use aggregate metrics if query id is not provided. For perquery metrics do not fetch querystring unless specified by user specifically. Have the aggregation on query hash to avoid fetching the querystring. Do not use latency metrics for anything. 89 | 1. `dbinsights.googleapis.com/aggregate/latencies`: Cumulative query latency distribution per user and database. `cloudsql_instance_database`. `user`, `client_addr`, `database`, `project_id`, `resource_id`. 90 | 2. `dbinsights.googleapis.com/aggregate/execution_time`: Cumulative query execution time per user and database. `cloudsql_instance_database`. `user`, `client_addr`, `database`, `project_id`, `resource_id`. 91 | 3. `dbinsights.googleapis.com/aggregate/execution_count`: Total number of query executions per user and database. `cloudsql_instance_database`. `user`, `client_addr`, `database`, `project_id`, `resource_id`. 92 | 4. `dbinsights.googleapis.com/aggregate/lock_time`: Cumulative lock wait time per user and database. `cloudsql_instance_database`. `user`, `client_addr`, `lock_type`, `database`, `project_id`, `resource_id`. 93 | 5. `dbinsights.googleapis.com/aggregate/io_time`: Cumulative IO wait time per user and database. `cloudsql_instance_database`. `user`, `client_addr`, `database`, `project_id`, `resource_id`. 94 | 6. 
`dbinsights.googleapis.com/aggregate/row_count`: Total number of rows affected during query execution. `cloudsql_instance_database`. `user`, `client_addr`, `row_status`, `database`, `project_id`, `resource_id`. 95 | 7. `dbinsights.googleapis.com/perquery/latencies`: Cumulative query latency distribution per user, database, and query. `cloudsql_instance_database`. `querystring`, `user`, `client_addr`, `query_hash`, `database`, `project_id`, `resource_id`. 96 | 8. `dbinsights.googleapis.com/perquery/execution_time`: Cumulative query execution time per user, database, and query. `cloudsql_instance_database`. `querystring`, `user`, `client_addr`, `query_hash`, `database`, `project_id`, `resource_id`. 97 | 9. `dbinsights.googleapis.com/perquery/execution_count`: Total number of query executions per user, database, and query. `cloudsql_instance_database`. `querystring`, `user`, `client_addr`, `query_hash`, `database`, `project_id`, `resource_id`. 98 | 10. `dbinsights.googleapis.com/perquery/lock_time`: Cumulative lock wait time per user, database, and query. `cloudsql_instance_database`. `querystring`, `user`, `client_addr`, `lock_type`, `query_hash`, `database`, `project_id`, `resource_id`. 99 | 11. `dbinsights.googleapis.com/perquery/io_time`: Cumulative io wait time per user, database, and query. `cloudsql_instance_database`. `querystring`, `user`, `client_addr`, `query_hash`, `database`, `project_id`, `resource_id`. 100 | 12. `dbinsights.googleapis.com/perquery/row_count`: Total number of rows affected during query execution. `cloudsql_instance_database`. `querystring`, `user`, `client_addr`, `query_hash`, `row_status`, `database`, `project_id`, `resource_id`. 101 | 13. `dbinsights.googleapis.com/pertag/latencies`: Cumulative query latency distribution per user, database, and tag. `cloudsql_instance_database`. `user`, `client_addr`, `action`, `application`, `controller`, `db_driver`, `framework`, `route`, `tag_hash`, `database`, `project_id`, `resource_id`. 102 | 14. `dbinsights.googleapis.com/pertag/execution_time`: Cumulative query execution time per user, database, and tag. `cloudsql_instance_database`. `user`, `client_addr`, `action`, `application`, `controller`, `db_driver`, `framework`, `route`, `tag_hash`, `database`, `project_id`, `resource_id`. 103 | 15. `dbinsights.googleapis.com/pertag/execution_count`: Total number of query executions per user, database, and tag. `cloudsql_instance_database`. `user`, `client_addr`, `action`, `application`, `controller`, `db_driver`, `framework`, `route`, `tag_hash`, `database`, `project_id`, `resource_id`. 104 | 16. `dbinsights.googleapis.com/pertag/lock_time`: Cumulative lock wait time per user, database and tag. `cloudsql_instance_database`. `user`, `client_addr`, `action`, `application`, `controller`, `db_driver`, `framework`, `route`, `lock_type`, `tag_hash`, `database`, `project_id`, `resource_id`. 105 | 17. `dbinsights.googleapis.com/pertag/io_time`: Cumulative IO wait time per user, database and tag. `cloudsql_instance_database`. `user`, `client_addr`, `action`, `application`, `controller`, `db_driver`, `framework`, `route`, `tag_hash`, `database`, `project_id`, `resource_id`. 106 | 18. `dbinsights.googleapis.com/pertag/row_count`: Total number of rows affected during query execution. `cloudsql_instance_database`. `user`, `client_addr`, `action`, `application`, `controller`, `db_driver`, `framework`, `route`, `tag_hash`, `row_status`, `database`, `project_id`, `resource_id`. 
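# Illustrative only (not part of the tool description above): a per-query PromQL
# query following the same conventions, aggregating execution time by query_hash
# so the querystring label is not fetched; project and instance IDs are placeholders.
#   topk(10, sum by ("query_hash")(avg_over_time({"__name__"="dbinsights.googleapis.com/perquery/execution_time","monitored_resource"="cloudsql_instance_database","project_id"="my-projectId","resource_id"="my-projectId:my-instanceId"}[5m])))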
107 | 108 | toolsets: 109 | cloud_sql_mysql_cloud_monitoring_tools: 110 | - get_system_metrics 111 | - get_query_metrics ``` -------------------------------------------------------------------------------- /internal/tools/firestore/firestorequery/firestorequery_test.go: -------------------------------------------------------------------------------- ```go 1 | // Copyright 2025 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package firestorequery_test 16 | 17 | import ( 18 | "testing" 19 | 20 | yaml "github.com/goccy/go-yaml" 21 | "github.com/google/go-cmp/cmp" 22 | "github.com/googleapis/genai-toolbox/internal/server" 23 | "github.com/googleapis/genai-toolbox/internal/testutils" 24 | "github.com/googleapis/genai-toolbox/internal/tools" 25 | "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorequery" 26 | ) 27 | 28 | func TestParseFromYamlFirestoreQuery(t *testing.T) { 29 | ctx, err := testutils.ContextWithNewLogger() 30 | if err != nil { 31 | t.Fatalf("unexpected error: %s", err) 32 | } 33 | tcs := []struct { 34 | desc string 35 | in string 36 | want server.ToolConfigs 37 | }{ 38 | { 39 | desc: "basic example with parameterized collection path", 40 | in: ` 41 | tools: 42 | query_users_tool: 43 | kind: firestore-query 44 | source: my-firestore-instance 45 | description: Query users collection with parameterized path 46 | collectionPath: "users/{{.userId}}/documents" 47 | parameters: 48 | - name: userId 49 | type: string 50 | description: The user ID to query documents for 51 | required: true 52 | `, 53 | want: server.ToolConfigs{ 54 | "query_users_tool": firestorequery.Config{ 55 | Name: "query_users_tool", 56 | Kind: "firestore-query", 57 | Source: "my-firestore-instance", 58 | Description: "Query users collection with parameterized path", 59 | CollectionPath: "users/{{.userId}}/documents", 60 | AuthRequired: []string{}, 61 | Parameters: tools.Parameters{ 62 | tools.NewStringParameterWithRequired("userId", "The user ID to query documents for", true), 63 | }, 64 | }, 65 | }, 66 | }, 67 | { 68 | desc: "with parameterized filters", 69 | in: ` 70 | tools: 71 | query_products_tool: 72 | kind: firestore-query 73 | source: prod-firestore 74 | description: Query products with dynamic filters 75 | collectionPath: "products" 76 | filters: | 77 | { 78 | "and": [ 79 | {"field": "category", "op": "==", "value": {"stringValue": "{{.category}}"}}, 80 | {"field": "price", "op": "<=", "value": {"doubleValue": {{.maxPrice}}}} 81 | ] 82 | } 83 | parameters: 84 | - name: category 85 | type: string 86 | description: Product category to filter by 87 | required: true 88 | - name: maxPrice 89 | type: float 90 | description: Maximum price for products 91 | required: true 92 | `, 93 | want: server.ToolConfigs{ 94 | "query_products_tool": firestorequery.Config{ 95 | Name: "query_products_tool", 96 | Kind: "firestore-query", 97 | Source: "prod-firestore", 98 | Description: "Query products with dynamic filters", 99 | 
CollectionPath: "products", 100 | Filters: `{ 101 | "and": [ 102 | {"field": "category", "op": "==", "value": {"stringValue": "{{.category}}"}}, 103 | {"field": "price", "op": "<=", "value": {"doubleValue": {{.maxPrice}}}} 104 | ] 105 | } 106 | `, 107 | AuthRequired: []string{}, 108 | Parameters: tools.Parameters{ 109 | tools.NewStringParameterWithRequired("category", "Product category to filter by", true), 110 | tools.NewFloatParameterWithRequired("maxPrice", "Maximum price for products", true), 111 | }, 112 | }, 113 | }, 114 | }, 115 | { 116 | desc: "with select fields and orderBy", 117 | in: ` 118 | tools: 119 | query_orders_tool: 120 | kind: firestore-query 121 | source: orders-firestore 122 | description: Query orders with field selection 123 | collectionPath: "orders" 124 | select: 125 | - orderId 126 | - customerName 127 | - totalAmount 128 | orderBy: 129 | field: "{{.sortField}}" 130 | direction: "DESCENDING" 131 | limit: 50 132 | parameters: 133 | - name: sortField 134 | type: string 135 | description: Field to sort by 136 | required: true 137 | `, 138 | want: server.ToolConfigs{ 139 | "query_orders_tool": firestorequery.Config{ 140 | Name: "query_orders_tool", 141 | Kind: "firestore-query", 142 | Source: "orders-firestore", 143 | Description: "Query orders with field selection", 144 | CollectionPath: "orders", 145 | Select: []string{"orderId", "customerName", "totalAmount"}, 146 | OrderBy: map[string]any{ 147 | "field": "{{.sortField}}", 148 | "direction": "DESCENDING", 149 | }, 150 | Limit: "50", 151 | AuthRequired: []string{}, 152 | Parameters: tools.Parameters{ 153 | tools.NewStringParameterWithRequired("sortField", "Field to sort by", true), 154 | }, 155 | }, 156 | }, 157 | }, 158 | { 159 | desc: "with auth requirements and complex filters", 160 | in: ` 161 | tools: 162 | secure_query_tool: 163 | kind: firestore-query 164 | source: secure-firestore 165 | description: Query with authentication and complex filters 166 | collectionPath: "{{.collection}}" 167 | filters: | 168 | { 169 | "or": [ 170 | { 171 | "and": [ 172 | {"field": "status", "op": "==", "value": {"stringValue": "{{.status}}"}}, 173 | {"field": "priority", "op": ">=", "value": {"integerValue": "{{.minPriority}}"}} 174 | ] 175 | }, 176 | {"field": "urgent", "op": "==", "value": {"booleanValue": true}} 177 | ] 178 | } 179 | analyzeQuery: true 180 | authRequired: 181 | - google-auth-service 182 | - api-key-service 183 | parameters: 184 | - name: collection 185 | type: string 186 | description: Collection name to query 187 | required: true 188 | - name: status 189 | type: string 190 | description: Status to filter by 191 | required: true 192 | - name: minPriority 193 | type: integer 194 | description: Minimum priority level 195 | default: 1 196 | `, 197 | want: server.ToolConfigs{ 198 | "secure_query_tool": firestorequery.Config{ 199 | Name: "secure_query_tool", 200 | Kind: "firestore-query", 201 | Source: "secure-firestore", 202 | Description: "Query with authentication and complex filters", 203 | CollectionPath: "{{.collection}}", 204 | Filters: `{ 205 | "or": [ 206 | { 207 | "and": [ 208 | {"field": "status", "op": "==", "value": {"stringValue": "{{.status}}"}}, 209 | {"field": "priority", "op": ">=", "value": {"integerValue": "{{.minPriority}}"}} 210 | ] 211 | }, 212 | {"field": "urgent", "op": "==", "value": {"booleanValue": true}} 213 | ] 214 | } 215 | `, 216 | AnalyzeQuery: true, 217 | AuthRequired: []string{"google-auth-service", "api-key-service"}, 218 | Parameters: tools.Parameters{ 219 | 
tools.NewStringParameterWithRequired("collection", "Collection name to query", true), 220 | tools.NewStringParameterWithRequired("status", "Status to filter by", true), 221 | tools.NewIntParameterWithDefault("minPriority", 1, "Minimum priority level"), 222 | }, 223 | }, 224 | }, 225 | }, 226 | { 227 | desc: "with Firestore native JSON value types and template parameters", 228 | in: ` 229 | tools: 230 | query_with_typed_values: 231 | kind: firestore-query 232 | source: typed-firestore 233 | description: Query with Firestore native JSON value types 234 | collectionPath: "countries" 235 | filters: | 236 | { 237 | "or": [ 238 | {"field": "continent", "op": "==", "value": {"stringValue": "{{.continent}}"}}, 239 | { 240 | "and": [ 241 | {"field": "area", "op": ">", "value": {"integerValue": "2000000"}}, 242 | {"field": "area", "op": "<", "value": {"integerValue": "3000000"}}, 243 | {"field": "population", "op": ">=", "value": {"integerValue": "{{.minPopulation}}"}}, 244 | {"field": "gdp", "op": ">", "value": {"doubleValue": {{.minGdp}}}}, 245 | {"field": "isActive", "op": "==", "value": {"booleanValue": {{.isActive}}}}, 246 | {"field": "lastUpdated", "op": ">=", "value": {"timestampValue": "{{.startDate}}"}} 247 | ] 248 | } 249 | ] 250 | } 251 | parameters: 252 | - name: continent 253 | type: string 254 | description: Continent to filter by 255 | required: true 256 | - name: minPopulation 257 | type: string 258 | description: Minimum population as string 259 | required: true 260 | - name: minGdp 261 | type: float 262 | description: Minimum GDP value 263 | required: true 264 | - name: isActive 265 | type: boolean 266 | description: Filter by active status 267 | required: true 268 | - name: startDate 269 | type: string 270 | description: Start date in RFC3339 format 271 | required: true 272 | `, 273 | want: server.ToolConfigs{ 274 | "query_with_typed_values": firestorequery.Config{ 275 | Name: "query_with_typed_values", 276 | Kind: "firestore-query", 277 | Source: "typed-firestore", 278 | Description: "Query with Firestore native JSON value types", 279 | CollectionPath: "countries", 280 | Filters: `{ 281 | "or": [ 282 | {"field": "continent", "op": "==", "value": {"stringValue": "{{.continent}}"}}, 283 | { 284 | "and": [ 285 | {"field": "area", "op": ">", "value": {"integerValue": "2000000"}}, 286 | {"field": "area", "op": "<", "value": {"integerValue": "3000000"}}, 287 | {"field": "population", "op": ">=", "value": {"integerValue": "{{.minPopulation}}"}}, 288 | {"field": "gdp", "op": ">", "value": {"doubleValue": {{.minGdp}}}}, 289 | {"field": "isActive", "op": "==", "value": {"booleanValue": {{.isActive}}}}, 290 | {"field": "lastUpdated", "op": ">=", "value": {"timestampValue": "{{.startDate}}"}} 291 | ] 292 | } 293 | ] 294 | } 295 | `, 296 | AuthRequired: []string{}, 297 | Parameters: tools.Parameters{ 298 | tools.NewStringParameterWithRequired("continent", "Continent to filter by", true), 299 | tools.NewStringParameterWithRequired("minPopulation", "Minimum population as string", true), 300 | tools.NewFloatParameterWithRequired("minGdp", "Minimum GDP value", true), 301 | tools.NewBooleanParameterWithRequired("isActive", "Filter by active status", true), 302 | tools.NewStringParameterWithRequired("startDate", "Start date in RFC3339 format", true), 303 | }, 304 | }, 305 | }, 306 | }, 307 | } 308 | for _, tc := range tcs { 309 | t.Run(tc.desc, func(t *testing.T) { 310 | got := struct { 311 | Tools server.ToolConfigs `yaml:"tools"` 312 | }{} 313 | // Parse contents 314 | err := 
yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) 315 | if err != nil { 316 | t.Fatalf("unable to unmarshal: %s", err) 317 | } 318 | if diff := cmp.Diff(tc.want, got.Tools); diff != "" { 319 | t.Fatalf("incorrect parse: diff %v", diff) 320 | } 321 | }) 322 | } 323 | } 324 | 325 | func TestParseFromYamlMultipleQueryTools(t *testing.T) { 326 | ctx, err := testutils.ContextWithNewLogger() 327 | if err != nil { 328 | t.Fatalf("unexpected error: %s", err) 329 | } 330 | in := ` 331 | tools: 332 | query_user_posts: 333 | kind: firestore-query 334 | source: social-firestore 335 | description: Query user posts with filtering 336 | collectionPath: "users/{{.userId}}/posts" 337 | filters: | 338 | { 339 | "and": [ 340 | {"field": "visibility", "op": "==", "value": {"stringValue": "{{.visibility}}"}}, 341 | {"field": "createdAt", "op": ">=", "value": {"timestampValue": "{{.startDate}}"}} 342 | ] 343 | } 344 | select: 345 | - title 346 | - content 347 | - likes 348 | orderBy: 349 | field: createdAt 350 | direction: "{{.sortOrder}}" 351 | limit: 20 352 | parameters: 353 | - name: userId 354 | type: string 355 | description: User ID whose posts to query 356 | required: true 357 | - name: visibility 358 | type: string 359 | description: Post visibility (public, private, friends) 360 | required: true 361 | - name: startDate 362 | type: string 363 | description: Start date for posts 364 | required: true 365 | - name: sortOrder 366 | type: string 367 | description: Sort order (ASCENDING or DESCENDING) 368 | default: "DESCENDING" 369 | query_inventory: 370 | kind: firestore-query 371 | source: inventory-firestore 372 | description: Query inventory items 373 | collectionPath: "warehouses/{{.warehouseId}}/inventory" 374 | filters: | 375 | { 376 | "field": "quantity", "op": "<", "value": {"integerValue": "{{.threshold}}"}} 377 | parameters: 378 | - name: warehouseId 379 | type: string 380 | description: Warehouse ID to check inventory 381 | required: true 382 | - name: threshold 383 | type: integer 384 | description: Quantity threshold for low stock 385 | required: true 386 | query_transactions: 387 | kind: firestore-query 388 | source: finance-firestore 389 | description: Query financial transactions 390 | collectionPath: "accounts/{{.accountId}}/transactions" 391 | filters: | 392 | { 393 | "or": [ 394 | {"field": "type", "op": "==", "value": {"stringValue": "{{.transactionType}}"}}, 395 | {"field": "amount", "op": ">", "value": {"doubleValue": {{.minAmount}}}} 396 | ] 397 | } 398 | analyzeQuery: true 399 | authRequired: 400 | - finance-auth 401 | parameters: 402 | - name: accountId 403 | type: string 404 | description: Account ID for transactions 405 | required: true 406 | - name: transactionType 407 | type: string 408 | description: Type of transaction 409 | default: "all" 410 | - name: minAmount 411 | type: float 412 | description: Minimum transaction amount 413 | default: 0 414 | ` 415 | want := server.ToolConfigs{ 416 | "query_user_posts": firestorequery.Config{ 417 | Name: "query_user_posts", 418 | Kind: "firestore-query", 419 | Source: "social-firestore", 420 | Description: "Query user posts with filtering", 421 | CollectionPath: "users/{{.userId}}/posts", 422 | Filters: `{ 423 | "and": [ 424 | {"field": "visibility", "op": "==", "value": {"stringValue": "{{.visibility}}"}}, 425 | {"field": "createdAt", "op": ">=", "value": {"timestampValue": "{{.startDate}}"}} 426 | ] 427 | } 428 | `, 429 | Select: []string{"title", "content", "likes"}, 430 | OrderBy: map[string]any{ 431 | "field": 
"createdAt", 432 | "direction": "{{.sortOrder}}", 433 | }, 434 | Limit: "20", 435 | AuthRequired: []string{}, 436 | Parameters: tools.Parameters{ 437 | tools.NewStringParameterWithRequired("userId", "User ID whose posts to query", true), 438 | tools.NewStringParameterWithRequired("visibility", "Post visibility (public, private, friends)", true), 439 | tools.NewStringParameterWithRequired("startDate", "Start date for posts", true), 440 | tools.NewStringParameterWithDefault("sortOrder", "DESCENDING", "Sort order (ASCENDING or DESCENDING)"), 441 | }, 442 | }, 443 | "query_inventory": firestorequery.Config{ 444 | Name: "query_inventory", 445 | Kind: "firestore-query", 446 | Source: "inventory-firestore", 447 | Description: "Query inventory items", 448 | CollectionPath: "warehouses/{{.warehouseId}}/inventory", 449 | Filters: `{ 450 | "field": "quantity", "op": "<", "value": {"integerValue": "{{.threshold}}"}} 451 | `, 452 | AuthRequired: []string{}, 453 | Parameters: tools.Parameters{ 454 | tools.NewStringParameterWithRequired("warehouseId", "Warehouse ID to check inventory", true), 455 | tools.NewIntParameterWithRequired("threshold", "Quantity threshold for low stock", true), 456 | }, 457 | }, 458 | "query_transactions": firestorequery.Config{ 459 | Name: "query_transactions", 460 | Kind: "firestore-query", 461 | Source: "finance-firestore", 462 | Description: "Query financial transactions", 463 | CollectionPath: "accounts/{{.accountId}}/transactions", 464 | Filters: `{ 465 | "or": [ 466 | {"field": "type", "op": "==", "value": {"stringValue": "{{.transactionType}}"}}, 467 | {"field": "amount", "op": ">", "value": {"doubleValue": {{.minAmount}}}} 468 | ] 469 | } 470 | `, 471 | AnalyzeQuery: true, 472 | AuthRequired: []string{"finance-auth"}, 473 | Parameters: tools.Parameters{ 474 | tools.NewStringParameterWithRequired("accountId", "Account ID for transactions", true), 475 | tools.NewStringParameterWithDefault("transactionType", "all", "Type of transaction"), 476 | tools.NewFloatParameterWithDefault("minAmount", 0, "Minimum transaction amount"), 477 | }, 478 | }, 479 | } 480 | 481 | got := struct { 482 | Tools server.ToolConfigs `yaml:"tools"` 483 | }{} 484 | // Parse contents 485 | err = yaml.UnmarshalContext(ctx, testutils.FormatYaml(in), &got) 486 | if err != nil { 487 | t.Fatalf("unable to unmarshal: %s", err) 488 | } 489 | if diff := cmp.Diff(want, got.Tools); diff != "" { 490 | t.Fatalf("incorrect parse: diff %v", diff) 491 | } 492 | } 493 | ``` -------------------------------------------------------------------------------- /tests/neo4j/neo4j_integration_test.go: -------------------------------------------------------------------------------- ```go 1 | // Copyright 2024 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package neo4j 16 | 17 | import ( 18 | "bytes" 19 | "context" 20 | "encoding/json" 21 | "io" 22 | "net/http" 23 | "os" 24 | "reflect" 25 | "regexp" 26 | "strings" 27 | "testing" 28 | "time" 29 | 30 | "github.com/neo4j/neo4j-go-driver/v5/neo4j" 31 | 32 | "github.com/googleapis/genai-toolbox/internal/testutils" 33 | "github.com/googleapis/genai-toolbox/tests" 34 | ) 35 | 36 | var ( 37 | Neo4jSourceKind = "neo4j" 38 | Neo4jDatabase = os.Getenv("NEO4J_DATABASE") 39 | Neo4jUri = os.Getenv("NEO4J_URI") 40 | Neo4jUser = os.Getenv("NEO4J_USER") 41 | Neo4jPass = os.Getenv("NEO4J_PASS") 42 | ) 43 | 44 | // getNeo4jVars retrieves necessary Neo4j connection details from environment variables. 45 | // It fails the test if any required variable is not set. 46 | func getNeo4jVars(t *testing.T) map[string]any { 47 | switch "" { 48 | case Neo4jDatabase: 49 | t.Fatal("'NEO4J_DATABASE' not set") 50 | case Neo4jUri: 51 | t.Fatal("'NEO4J_URI' not set") 52 | case Neo4jUser: 53 | t.Fatal("'NEO4J_USER' not set") 54 | case Neo4jPass: 55 | t.Fatal("'NEO4J_PASS' not set") 56 | } 57 | 58 | return map[string]any{ 59 | "kind": Neo4jSourceKind, 60 | "uri": Neo4jUri, 61 | "database": Neo4jDatabase, 62 | "user": Neo4jUser, 63 | "password": Neo4jPass, 64 | } 65 | } 66 | 67 | // TestNeo4jToolEndpoints sets up an integration test server and tests the API endpoints 68 | // for various Neo4j tools, including cypher execution and schema retrieval. 69 | func TestNeo4jToolEndpoints(t *testing.T) { 70 | sourceConfig := getNeo4jVars(t) 71 | ctx, cancel := context.WithTimeout(context.Background(), time.Minute) 72 | defer cancel() 73 | 74 | var args []string 75 | 76 | // Write config into a file and pass it to the command. 77 | // This configuration defines the data source and the tools to be tested. 78 | toolsFile := map[string]any{ 79 | "sources": map[string]any{ 80 | "my-neo4j-instance": sourceConfig, 81 | }, 82 | "tools": map[string]any{ 83 | "my-simple-cypher-tool": map[string]any{ 84 | "kind": "neo4j-cypher", 85 | "source": "my-neo4j-instance", 86 | "description": "Simple tool to test end to end functionality.", 87 | "statement": "RETURN 1 as a;", 88 | }, 89 | "my-simple-execute-cypher-tool": map[string]any{ 90 | "kind": "neo4j-execute-cypher", 91 | "source": "my-neo4j-instance", 92 | "description": "Simple tool to test end to end functionality.", 93 | }, 94 | "my-readonly-execute-cypher-tool": map[string]any{ 95 | "kind": "neo4j-execute-cypher", 96 | "source": "my-neo4j-instance", 97 | "description": "A readonly cypher execution tool.", 98 | "readOnly": true, 99 | }, 100 | "my-schema-tool": map[string]any{ 101 | "kind": "neo4j-schema", 102 | "source": "my-neo4j-instance", 103 | "description": "A tool to get the Neo4j schema.", 104 | }, 105 | "my-schema-tool-with-cache": map[string]any{ 106 | "kind": "neo4j-schema", 107 | "source": "my-neo4j-instance", 108 | "description": "A schema tool with a custom cache expiration.", 109 | "cacheExpireMinutes": 10, 110 | }, 111 | "my-populated-schema-tool": map[string]any{ 112 | "kind": "neo4j-schema", 113 | "source": "my-neo4j-instance", 114 | "description": "A tool to get the Neo4j schema from a populated DB.", 115 | }, 116 | }, 117 | } 118 | cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) 
119 | if err != nil { 120 | t.Fatalf("command initialization returned an error: %s", err) 121 | } 122 | defer cleanup() 123 | 124 | waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) 125 | defer cancel() 126 | out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) 127 | if err != nil { 128 | t.Logf("toolbox command logs: \n%s", out) 129 | t.Fatalf("toolbox didn't start successfully: %s", err) 130 | } 131 | 132 | // Test tool `GET` endpoints to verify their manifests are correct. 133 | tcs := []struct { 134 | name string 135 | api string 136 | want map[string]any 137 | }{ 138 | { 139 | name: "get my-simple-cypher-tool", 140 | api: "http://127.0.0.1:5000/api/tool/my-simple-cypher-tool/", 141 | want: map[string]any{ 142 | "my-simple-cypher-tool": map[string]any{ 143 | "description": "Simple tool to test end to end functionality.", 144 | "parameters": []any{}, 145 | "authRequired": []any{}, 146 | }, 147 | }, 148 | }, 149 | { 150 | name: "get my-simple-execute-cypher-tool", 151 | api: "http://127.0.0.1:5000/api/tool/my-simple-execute-cypher-tool/", 152 | want: map[string]any{ 153 | "my-simple-execute-cypher-tool": map[string]any{ 154 | "description": "Simple tool to test end to end functionality.", 155 | "parameters": []any{ 156 | map[string]any{ 157 | "name": "cypher", 158 | "type": "string", 159 | "required": true, 160 | "description": "The cypher to execute.", 161 | "authSources": []any{}, 162 | }, 163 | }, 164 | "authRequired": []any{}, 165 | }, 166 | }, 167 | }, 168 | { 169 | name: "get my-schema-tool", 170 | api: "http://127.0.0.1:5000/api/tool/my-schema-tool/", 171 | want: map[string]any{ 172 | "my-schema-tool": map[string]any{ 173 | "description": "A tool to get the Neo4j schema.", 174 | "parameters": []any{}, 175 | "authRequired": []any{}, 176 | }, 177 | }, 178 | }, 179 | { 180 | name: "get my-schema-tool-with-cache", 181 | api: "http://127.0.0.1:5000/api/tool/my-schema-tool-with-cache/", 182 | want: map[string]any{ 183 | "my-schema-tool-with-cache": map[string]any{ 184 | "description": "A schema tool with a custom cache expiration.", 185 | "parameters": []any{}, 186 | "authRequired": []any{}, 187 | }, 188 | }, 189 | }, 190 | } 191 | for _, tc := range tcs { 192 | t.Run(tc.name, func(t *testing.T) { 193 | resp, err := http.Get(tc.api) 194 | if err != nil { 195 | t.Fatalf("error when sending a request: %s", err) 196 | } 197 | defer resp.Body.Close() 198 | if resp.StatusCode != 200 { 199 | t.Fatalf("response status code is not 200") 200 | } 201 | 202 | var body map[string]interface{} 203 | err = json.NewDecoder(resp.Body).Decode(&body) 204 | if err != nil { 205 | t.Fatalf("error parsing response body") 206 | } 207 | 208 | got, ok := body["tools"] 209 | if !ok { 210 | t.Fatalf("unable to find tools in response body") 211 | } 212 | if !reflect.DeepEqual(got, tc.want) { 213 | t.Fatalf("got %q, want %q", got, tc.want) 214 | } 215 | }) 216 | } 217 | 218 | // Test tool `invoke` endpoints to verify their functionality. 
219 | invokeTcs := []struct { 220 | name string 221 | api string 222 | requestBody io.Reader 223 | want string 224 | wantStatus int 225 | wantErrorSubstring string 226 | prepareData func(t *testing.T) 227 | validateFunc func(t *testing.T, body string) 228 | }{ 229 | { 230 | name: "invoke my-simple-cypher-tool", 231 | api: "http://127.0.0.1:5000/api/tool/my-simple-cypher-tool/invoke", 232 | requestBody: bytes.NewBuffer([]byte(`{}`)), 233 | want: "[{\"a\":1}]", 234 | wantStatus: http.StatusOK, 235 | }, 236 | { 237 | name: "invoke my-simple-execute-cypher-tool", 238 | api: "http://127.0.0.1:5000/api/tool/my-simple-execute-cypher-tool/invoke", 239 | requestBody: bytes.NewBuffer([]byte(`{"cypher": "RETURN 1 as a;"}`)), 240 | want: "[{\"a\":1}]", 241 | wantStatus: http.StatusOK, 242 | }, 243 | { 244 | name: "invoke readonly tool with write query", 245 | api: "http://127.0.0.1:5000/api/tool/my-readonly-execute-cypher-tool/invoke", 246 | requestBody: bytes.NewBuffer([]byte(`{"cypher": "CREATE (n:TestNode)"}`)), 247 | wantStatus: http.StatusBadRequest, 248 | wantErrorSubstring: "this tool is read-only and cannot execute write queries", 249 | }, 250 | { 251 | name: "invoke my-schema-tool", 252 | api: "http://127.0.0.1:5000/api/tool/my-schema-tool/invoke", 253 | requestBody: bytes.NewBuffer([]byte(`{}`)), 254 | wantStatus: http.StatusOK, 255 | validateFunc: func(t *testing.T, body string) { 256 | var result map[string]any 257 | if err := json.Unmarshal([]byte(body), &result); err != nil { 258 | t.Fatalf("failed to unmarshal schema result: %v", err) 259 | } 260 | // Check for the presence of top-level keys in the schema response. 261 | expectedKeys := []string{"nodeLabels", "relationships", "constraints", "indexes", "databaseInfo", "statistics"} 262 | for _, key := range expectedKeys { 263 | if _, ok := result[key]; !ok { 264 | t.Errorf("expected key %q not found in schema response", key) 265 | } 266 | } 267 | }, 268 | }, 269 | { 270 | name: "invoke my-schema-tool-with-cache", 271 | api: "http://127.0.0.1:5000/api/tool/my-schema-tool-with-cache/invoke", 272 | requestBody: bytes.NewBuffer([]byte(`{}`)), 273 | wantStatus: http.StatusOK, 274 | validateFunc: func(t *testing.T, body string) { 275 | var result map[string]any 276 | if err := json.Unmarshal([]byte(body), &result); err != nil { 277 | t.Fatalf("failed to unmarshal schema result: %v", err) 278 | } 279 | // Also check the structure of the schema response for the cached tool. 280 | expectedKeys := []string{"nodeLabels", "relationships", "constraints", "indexes", "databaseInfo", "statistics"} 281 | for _, key := range expectedKeys { 282 | if _, ok := result[key]; !ok { 283 | t.Errorf("expected key %q not found in schema response", key) 284 | } 285 | } 286 | }, 287 | }, 288 | { 289 | name: "invoke my-schema-tool with populated data", 290 | api: "http://127.0.0.1:5000/api/tool/my-populated-schema-tool/invoke", 291 | requestBody: bytes.NewBuffer([]byte(`{}`)), 292 | wantStatus: http.StatusOK, 293 | prepareData: func(t *testing.T) { 294 | ctx := context.Background() 295 | driver, err := neo4j.NewDriverWithContext(Neo4jUri, neo4j.BasicAuth(Neo4jUser, Neo4jPass, "")) 296 | if err != nil { 297 | t.Fatalf("failed to create neo4j driver: %v", err) 298 | } 299 | 300 | // Helper to execute queries for setup and teardown. 301 | execute := func(query string) { 302 | session := driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: Neo4jDatabase}) 303 | defer session.Close(ctx) 304 | // Use ExecuteWrite to ensure the query is committed before proceeding. 
305 | _, err := session.ExecuteWrite(ctx, func(tx neo4j.ManagedTransaction) (any, error) { 306 | _, err := tx.Run(ctx, query, nil) 307 | return nil, err 308 | }) 309 | 310 | // Don't fail the test on teardown errors (e.g., entity doesn't exist). 311 | if err != nil && !strings.Contains(query, "DROP") { 312 | t.Fatalf("query failed: %s\nerror: %v", query, err) 313 | } 314 | } 315 | 316 | // Teardown logic is deferred to ensure it runs even if the test fails. 317 | // The driver will be closed at the end of this block. 318 | t.Cleanup(func() { 319 | execute("DROP CONSTRAINT PersonNameUnique IF EXISTS") 320 | execute("DROP INDEX MovieTitleIndex IF EXISTS") 321 | execute("MATCH (n) DETACH DELETE n") 322 | if err := driver.Close(ctx); err != nil { 323 | t.Errorf("failed to close driver during cleanup: %v", err) 324 | } 325 | }) 326 | 327 | // Setup: Create constraints, indexes, and data. 328 | execute("MERGE (p:Person {name: 'Alice'}) MERGE (m:Movie {title: 'The Matrix'}) MERGE (p)-[:ACTED_IN]->(m)") 329 | execute("CREATE CONSTRAINT PersonNameUnique IF NOT EXISTS FOR (p:Person) REQUIRE p.name IS UNIQUE") 330 | execute("CREATE INDEX MovieTitleIndex IF NOT EXISTS FOR (m:Movie) ON (m.title)") 331 | }, 332 | validateFunc: func(t *testing.T, body string) { 333 | // Define structs for unmarshaling the detailed schema. 334 | type Property struct { 335 | Name string `json:"name"` 336 | Types []string `json:"types"` 337 | } 338 | type NodeLabel struct { 339 | Name string `json:"name"` 340 | Properties []Property `json:"properties"` 341 | } 342 | type Relationship struct { 343 | Type string `json:"type"` 344 | StartNode string `json:"startNode"` 345 | EndNode string `json:"endNode"` 346 | } 347 | type Constraint struct { 348 | Name string `json:"name"` 349 | Label string `json:"label"` 350 | Properties []string `json:"properties"` 351 | } 352 | type Index struct { 353 | Name string `json:"name"` 354 | Label string `json:"label"` 355 | Properties []string `json:"properties"` 356 | } 357 | type Schema struct { 358 | NodeLabels []NodeLabel `json:"nodeLabels"` 359 | Relationships []Relationship `json:"relationships"` 360 | Constraints []Constraint `json:"constraints"` 361 | Indexes []Index `json:"indexes"` 362 | } 363 | 364 | var schema Schema 365 | if err := json.Unmarshal([]byte(body), &schema); err != nil { 366 | t.Fatalf("failed to unmarshal schema json: %v\nResponse body: %s", err, body) 367 | } 368 | 369 | // --- Validate Node Labels and Properties --- 370 | var personLabelFound, movieLabelFound bool 371 | for _, l := range schema.NodeLabels { 372 | if l.Name == "Person" { 373 | personLabelFound = true 374 | propFound := false 375 | for _, p := range l.Properties { 376 | if p.Name == "name" { 377 | propFound = true 378 | break 379 | } 380 | } 381 | if !propFound { 382 | t.Errorf("expected Person label to have 'name' property, but it was not found") 383 | } 384 | } 385 | if l.Name == "Movie" { 386 | movieLabelFound = true 387 | propFound := false 388 | for _, p := range l.Properties { 389 | if p.Name == "title" { 390 | propFound = true 391 | break 392 | } 393 | } 394 | if !propFound { 395 | t.Errorf("expected Movie label to have 'title' property, but it was not found") 396 | } 397 | } 398 | } 399 | if !personLabelFound { 400 | t.Error("expected to find 'Person' in nodeLabels") 401 | } 402 | if !movieLabelFound { 403 | t.Error("expected to find 'Movie' in nodeLabels") 404 | } 405 | 406 | // --- Validate Relationships --- 407 | relFound := false 408 | for _, r := range schema.Relationships { 409 | if 
r.Type == "ACTED_IN" && r.StartNode == "Person" && r.EndNode == "Movie" { 410 | relFound = true 411 | break 412 | } 413 | } 414 | if !relFound { 415 | t.Errorf("expected to find relationship '(:Person)-[:ACTED_IN]->(:Movie)', but it was not found") 416 | } 417 | 418 | // --- Validate Constraints --- 419 | constraintFound := false 420 | for _, c := range schema.Constraints { 421 | if c.Name == "PersonNameUnique" && c.Label == "Person" { 422 | propFound := false 423 | for _, p := range c.Properties { 424 | if p == "name" { 425 | propFound = true 426 | break 427 | } 428 | } 429 | if propFound { 430 | constraintFound = true 431 | break 432 | } 433 | } 434 | } 435 | if !constraintFound { 436 | t.Errorf("expected to find constraint 'PersonNameUnique' on Person(name), but it was not found") 437 | } 438 | 439 | // --- Validate Indexes --- 440 | indexFound := false 441 | for _, i := range schema.Indexes { 442 | if i.Name == "MovieTitleIndex" && i.Label == "Movie" { 443 | propFound := false 444 | for _, p := range i.Properties { 445 | if p == "title" { 446 | propFound = true 447 | break 448 | } 449 | } 450 | if propFound { 451 | indexFound = true 452 | break 453 | } 454 | } 455 | } 456 | if !indexFound { 457 | t.Errorf("expected to find index 'MovieTitleIndex' on Movie(title), but it was not found") 458 | } 459 | }, 460 | }, 461 | } 462 | for _, tc := range invokeTcs { 463 | t.Run(tc.name, func(t *testing.T) { 464 | // Prepare data if a preparation function is provided. 465 | if tc.prepareData != nil { 466 | tc.prepareData(t) 467 | } 468 | 469 | resp, err := http.Post(tc.api, "application/json", tc.requestBody) 470 | if err != nil { 471 | t.Fatalf("error when sending a request: %s", err) 472 | } 473 | defer resp.Body.Close() 474 | if resp.StatusCode != tc.wantStatus { 475 | bodyBytes, _ := io.ReadAll(resp.Body) 476 | t.Fatalf("response status code: got %d, want %d: %s", resp.StatusCode, tc.wantStatus, string(bodyBytes)) 477 | } 478 | 479 | if tc.wantStatus == http.StatusOK { 480 | var body map[string]interface{} 481 | err = json.NewDecoder(resp.Body).Decode(&body) 482 | if err != nil { 483 | t.Fatalf("error parsing response body") 484 | } 485 | got, ok := body["result"].(string) 486 | if !ok { 487 | t.Fatalf("unable to find result in response body") 488 | } 489 | 490 | if tc.validateFunc != nil { 491 | // Use the custom validation function if provided. 492 | tc.validateFunc(t, got) 493 | } else if got != tc.want { 494 | // Otherwise, perform a direct string comparison. 495 | t.Fatalf("unexpected value: got %q, want %q", got, tc.want) 496 | } 497 | } else { 498 | bodyBytes, err := io.ReadAll(resp.Body) 499 | if err != nil { 500 | t.Fatalf("failed to read error response body: %s", err) 501 | } 502 | bodyString := string(bodyBytes) 503 | if !strings.Contains(bodyString, tc.wantErrorSubstring) { 504 | t.Fatalf("response body %q does not contain expected error %q", bodyString, tc.wantErrorSubstring) 505 | } 506 | } 507 | }) 508 | } 509 | } 510 | ```
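The integration test above drives the Toolbox HTTP API directly. As a minimal illustration (not part of the repository), the sketch below shows how any Go client could invoke one of the configured tools the same way the test does. It assumes a Toolbox server is already running on 127.0.0.1:5000 with `my-simple-cypher-tool` defined as in the test's `toolsFile`; the endpoint path and the string-typed `result` field are taken from the test's own request and response handling, and the `result` struct tag here is an assumption based on how the test reads `body["result"]`.

```go
// Minimal sketch of a client for the tool invoke endpoint exercised in the test above.
// Assumes a running Toolbox server on 127.0.0.1:5000 with "my-simple-cypher-tool" configured.
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Invoke the tool with an empty JSON parameter object, as the test does.
	resp, err := http.Post(
		"http://127.0.0.1:5000/api/tool/my-simple-cypher-tool/invoke",
		"application/json",
		bytes.NewBufferString(`{}`),
	)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		panic(fmt.Sprintf("unexpected status: %d", resp.StatusCode))
	}

	// The server wraps the tool output in a JSON object whose "result" field is a string
	// (the test reads body["result"] as a string and compares it to `[{"a":1}]`).
	var body struct {
		Result string `json:"result"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
		panic(err)
	}
	fmt.Println(body.Result) // expected to print something like [{"a":1}]
}
```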