# Directory Structure
```
├── .dockerignore
├── .github
│ ├── CODEOWNERS
│ ├── dependabot.yml
│ └── workflows
│ ├── contributors-list.yml
│ ├── test.yml
│ ├── trunk_check.yml
│ └── trunk_upgrade.yml
├── .gitignore
├── .python-version
├── .trunk
│ ├── .gitignore
│ ├── configs
│ │ ├── .checkov.yml
│ │ ├── .isort.cfg
│ │ ├── .markdownlint.yaml
│ │ ├── .shellcheckrc
│ │ ├── .yamllint.yaml
│ │ └── ruff.toml
│ └── trunk.yaml
├── config.yml.template
├── dev
│ ├── build.sh
│ ├── clean.sh
│ ├── publish.sh
│ ├── setup.sh
│ └── test_python.sh
├── Dockerfile
├── docs
│ └── img
│ └── archirecture.png
├── LICENSE
├── Makefile
├── pyproject.toml
├── README.md
├── requirements.setup.txt
├── src
│ ├── mcp_vertexai_search
│ │ ├── __init__.py
│ │ ├── __main__.py
│ │ ├── agent.py
│ │ ├── cli.py
│ │ ├── config.py
│ │ ├── google_cloud.py
│ │ ├── server.py
│ │ └── utils.py
│ └── research_agent
│ ├── __init__.py
│ ├── chat.py
│ ├── mcp_client.py
│ └── utils.py
├── tests
│ ├── __init__.py
│ ├── test_config.py
│ └── test_utils.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
```
3.12
```
--------------------------------------------------------------------------------
/.trunk/configs/.isort.cfg:
--------------------------------------------------------------------------------
```
[settings]
profile=black
```
--------------------------------------------------------------------------------
/.trunk/configs/.checkov.yml:
--------------------------------------------------------------------------------
```yaml
skip-check:
- CKV2_GHA_1
```
--------------------------------------------------------------------------------
/.trunk/.gitignore:
--------------------------------------------------------------------------------
```
*out
*logs
*actions
*notifications
*tools
plugins
user_trunk.yaml
user.yaml
tmp
```
--------------------------------------------------------------------------------
/.trunk/configs/.markdownlint.yaml:
--------------------------------------------------------------------------------
```yaml
# Prettier friendly markdownlint config (all formatting rules disabled)
extends: markdownlint/style/prettier
```
--------------------------------------------------------------------------------
/.trunk/configs/.yamllint.yaml:
--------------------------------------------------------------------------------
```yaml
rules:
quoted-strings:
required: only-when-needed
extra-allowed: ["{|}"]
key-duplicates: {}
octal-values:
forbid-implicit-octal: true
```
--------------------------------------------------------------------------------
/.trunk/configs/.shellcheckrc:
--------------------------------------------------------------------------------
```
enable=all
source-path=SCRIPTDIR
disable=SC2154
# If you're having issues with shellcheck following source, disable the errors via:
# disable=SC1090
# disable=SC1091
```
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
```
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# PyPI configuration file
.pypirc
```
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
```
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# PyPI configuration file
.pypirc
# server config
config.yml
```
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
```markdown
# MCP Server for Vertex AI Search
This is an MCP server to search documents using Vertex AI.
## Architecture
This solution uses Gemini with Vertex AI grounding to search documents using your private data.
Grounding improves the quality of search results by grounding Gemini's responses in your data stored in Vertex AI Datastore.
We can integrate one or multiple Vertex AI data stores to the MCP server.
For more details on grounding, refer to [Vertex AI Grounding Documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/ground-with-your-data).

## How to use
There are two ways to use this MCP server.
If you want to run this on Docker, the first approach would be good as Dockerfile is provided in the project.
### 1. Clone the repository
```shell
# Clone the repository
git clone [email protected]:ubie-oss/mcp-vertexai-search.git
# Create a virtual environment
uv venv
# Install the dependencies
uv sync --all-extras
# Check the command
uv run mcp-vertexai-search
```
### 2. Install the python package
The package isn't published to PyPI yet, but we can install it from the repository.
We need a config file derived from [config.yml.template](./config.yml.template) to run the MCP server, because the python package doesn't include the config template.
Please refer to [Appendix A: Config file](#appendix-a-config-file) for the details of the config file.
```shell
# Install the package
pip install git+https://github.com/ubie-oss/mcp-vertexai-search.git
# Check the command
mcp-vertexai-search --help
```
## Development
### Prerequisites
- [uv](https://docs.astral.sh/uv/getting-started/installation/)
- Vertex AI data store
- Please look into [the official documentation about data stores](https://cloud.google.com/generative-ai-app-builder/docs/create-datastore-ingest) for more information
### Set up Local Environment
```shell
# Optional: Install uv
python -m pip install -r requirements.setup.txt
# Create a virtual environment
uv venv
uv sync --all-extras
```
### Run the MCP server
This supports two transports for SSE (Server-Sent Events) and stdio (Standard Input Output).
We can control the transport by setting the `--transport` flag.
We can configure the MCP server with a YAML file.
[config.yml.template](./config.yml.template) is a template for the config file.
Please modify the config file to fit your needs.
```bash
uv run mcp-vertexai-search serve \
--config config.yml \
--transport <stdio|sse>
```
### Test the Vertex AI Search
We can test the Vertex AI Search by using the `mcp-vertexai-search search` command without the MCP server.
```bash
uv run mcp-vertexai-search search \
--config config.yml \
--query <your-query>
```
## Appendix A: Config file
[config.yml.template](./config.yml.template) is a template for the config file.
- `server`
- `server.name`: The name of the MCP server
- `model`
- `model.model_name`: The name of the Vertex AI model
- `model.project_id`: The project ID of the Vertex AI model
- `model.location`: The location of the model (e.g. us-central1)
- `model.impersonate_service_account`: The service account to impersonate
- `model.generate_content_config`: The configuration for the generate content API
- `data_stores`: The list of Vertex AI data stores
- `data_stores.project_id`: The project ID of the Vertex AI data store
- `data_stores.location`: The location of the Vertex AI data store (e.g. us)
- `data_stores.datastore_id`: The ID of the Vertex AI data store
- `data_stores.tool_name`: The name of the tool
- `data_stores.description`: The description of the Vertex AI data store
```
--------------------------------------------------------------------------------
/src/mcp_vertexai_search/__init__.py:
--------------------------------------------------------------------------------
```python
```
--------------------------------------------------------------------------------
/src/research_agent/__init__.py:
--------------------------------------------------------------------------------
```python
```
--------------------------------------------------------------------------------
/requirements.setup.txt:
--------------------------------------------------------------------------------
```
uv>=0.6
```
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
```python
# Version marker for the tests package.
__version__ = "0.0.1"
```
--------------------------------------------------------------------------------
/dev/build.sh:
--------------------------------------------------------------------------------
```bash
#!/bin/bash
set -Eo pipefail

# Build the sdist and wheel into dist/ using uv.
uv build
```
--------------------------------------------------------------------------------
/src/mcp_vertexai_search/__main__.py:
--------------------------------------------------------------------------------
```python
from mcp_vertexai_search.cli import serve

# Entry point for `python -m mcp_vertexai_search`. The guard keeps importing
# this module side-effect free (previously `serve()` ran on any import).
if __name__ == "__main__":
    serve()
```
--------------------------------------------------------------------------------
/.trunk/configs/ruff.toml:
--------------------------------------------------------------------------------
```toml
# Generic, formatter-friendly config.
select = ["B", "D3", "E", "F"]
# Never enforce `E501` (line length violations). This should be handled by formatters.
ignore = ["E501"]
```
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
```dockerfile
# Runtime image for the MCP Vertex AI Search server.
FROM python:3.12-slim

WORKDIR /app

# Copy only dependency metadata first so this layer is cached until
# pyproject.toml / uv.lock change, then resolve the environment with uv.
COPY requirements.setup.txt pyproject.toml uv.lock /app/
RUN python -m pip install --no-cache-dir -r requirements.setup.txt \
    && uv venv \
    && uv sync

# Copy the application source last (changes here don't invalidate deps).
COPY . /app

ENTRYPOINT ["uv", "run", "mcp-vertexai-search"]
```
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
```python
import unittest
from mcp_vertexai_search.utils import to_mcp_tool
class TestUtils(unittest.TestCase):
    """Unit tests for mcp_vertexai_search.utils."""

    def test_to_mcp_tool(self):
        """The generated tool carries the given name and description through."""
        expected_name = "test-tool"
        expected_description = "test-description"
        converted = to_mcp_tool(expected_name, expected_description)
        self.assertEqual(converted.name, expected_name)
        self.assertEqual(converted.description, expected_description)
```
--------------------------------------------------------------------------------
/.github/workflows/contributors-list.yml:
--------------------------------------------------------------------------------
```yaml
name: Generate contributors list
on:
push:
branches:
- main
jobs:
# SEE https://github.com/marketplace/actions/contribute-list
contrib-readme-job:
runs-on: ubuntu-latest
name: A job to automate contrib in readme
steps:
- name: Contribute List
uses: akhilmhdh/[email protected]
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
```
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
```yaml
# See GitHub's documentation for more information on this file:
# https://docs.github.com/en/code-security/supply-chain-security/keeping-your-dependencies-updated-automatically/configuration-options-for-dependency-updates
version: 2
updates:
- package-ecosystem: github-actions
directory: /
schedule:
interval: weekly
# - package-ecosystem: pip
# directory: /
# schedule:
# interval: weekly
```
--------------------------------------------------------------------------------
/.github/workflows/trunk_check.yml:
--------------------------------------------------------------------------------
```yaml
name: Trunk Check
on:
pull_request:
workflow_dispatch:
concurrency:
group: ${{ github.head_ref || github.run_id }}
cancel-in-progress: true
permissions: read-all
jobs:
trunk_check:
name: Trunk Check Runner
runs-on: ubuntu-latest
permissions:
checks: write # For trunk to post annotations
contents: read # For repo checkout
steps:
- name: Checkout
uses: actions/checkout@v5
- name: Trunk Check
uses: trunk-io/trunk-action@v1
```
--------------------------------------------------------------------------------
/dev/publish.sh:
--------------------------------------------------------------------------------
```bash
#!/bin/bash
# Build the package with uv and publish it to PyPI or TestPyPI.
#
# Usage: publish.sh <pypi|testpypi>
set -Eo pipefail
set -x

# Constants
SCRIPT_FILE="$(readlink -f "$0")"
SCRIPT_DIR="$(dirname "${SCRIPT_FILE}")"
MODULE_DIR="$(dirname "${SCRIPT_DIR}")"

cd "${MODULE_DIR}" || exit

# Arguments
target=${1:?"target is not set"}

# Ensure uv is installed
pip install uv

# Build the package first
uv build

# Publish to the specified target
if [[ ${target} == "pypi" ]]; then
  uv publish
elif [[ ${target} == "testpypi" ]]; then
  uv publish --publish-url "https://test.pypi.org/legacy/"
else
  echo "No such target ${target}"
  exit 1
fi
```
--------------------------------------------------------------------------------
/.github/workflows/trunk_upgrade.yml:
--------------------------------------------------------------------------------
```yaml
name: Upgrade Trunk
on:
workflow_dispatch: {}
schedule:
# Runs the first day of every month (in the UTC timezone)
- cron: 0 0 1 * *
permissions: read-all
jobs:
trunk_upgrade:
name: Upgrade Trunk
runs-on: ubuntu-latest
permissions:
contents: write # For trunk to create PRs
pull-requests: write # For trunk to create PRs
steps:
- name: Checkout
uses: actions/checkout@v5
# >>> Install your own deps here (npm install, etc) <<<
# SEE https://github.com/trunk-io/trunk-action
- name: Trunk Upgrade
uses: trunk-io/trunk-action/upgrade@v1
with:
signoff: true
```
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
```yaml
name: Test python
on:
pull_request:
paths:
- .github/workflows/test.yml
- pyproject.toml
- src/**/*.py
- tests/**/*.py
push:
branches:
- main
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11", "3.12"]
fail-fast: false
defaults:
run:
shell: bash
steps:
- uses: actions/checkout@v5
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install -r requirements.setup.txt
bash dev/setup.sh --deps "development"
- name: Run tests
run: bash dev/test_python.sh
- name: Test build
run: |
bash dev/build.sh
```
--------------------------------------------------------------------------------
/dev/test_python.sh:
--------------------------------------------------------------------------------
```bash
#!/bin/bash
# Run the Python test suite with pytest.
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -Eeuo pipefail

# Constants: resolve the repository root relative to this script.
SCRIPT_FILE="$(readlink -f "$0")"
SCRIPT_DIR="$(dirname "${SCRIPT_FILE}")"
MODULE_DIR="$(dirname "${SCRIPT_DIR}")"

# Run all tests under tests/, clearing pytest's cache first.
pytest -v -s --cache-clear "${MODULE_DIR}/tests"
```
--------------------------------------------------------------------------------
/src/research_agent/utils.py:
--------------------------------------------------------------------------------
```python
from google import genai
from google.genai import types as genai_types
from mcp import types as mcp_types
def to_gemini_tool(mcp_tool: mcp_types.Tool) -> genai_types.Tool:
    """Convert an MCP tool schema to a Gemini tool.

    Args:
        mcp_tool: The MCP tool whose name, description, and input schema are
            translated into a Gemini function declaration.

    Returns:
        A Gemini tool wrapping a single function declaration.
    """
    required_params: list[str] = mcp_tool.inputSchema.get("required", [])
    properties = {}
    for key, value in mcp_tool.inputSchema.get("properties", {}).items():
        # Gemini schema type names are upper-case (e.g. STRING, OBJECT);
        # missing types fall back to STRING.
        properties[key] = genai_types.Schema(
            type=value.get("type", "STRING").upper(),
            description=value.get("description", ""),
        )
    # Use the genai_types alias consistently (the original mixed genai.types
    # and genai_types, which refer to the same module).
    function = genai_types.FunctionDeclaration(
        name=mcp_tool.name,
        description=mcp_tool.description,
        parameters=genai_types.Schema(
            type="OBJECT",
            properties=properties,
            required=required_params,
        ),
    )
    return genai_types.Tool(function_declarations=[function])
```
--------------------------------------------------------------------------------
/dev/clean.sh:
--------------------------------------------------------------------------------
```bash
#!/usr/bin/env bash
# Remove build artifacts and caches from the repository root.
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
set -x

# Constants: resolve the repository root relative to this script.
SCRIPT_FILE="$(readlink -f "$0")"
SCRIPT_DIR="$(dirname "${SCRIPT_FILE}")"
MODULE_DIR="$(dirname "${SCRIPT_DIR}")"

# Directories to remove, relative to the repository root.
cleaned_dirs=(
  dist
  sdist
  .pytest_cache
)
for cleaned_dir in "${cleaned_dirs[@]}"; do
  if [[ -d "${MODULE_DIR}/${cleaned_dir}" ]]; then
    # ${MODULE_DIR:?} aborts if the variable is empty, guarding against rm -r /<dir>.
    rm -r "${MODULE_DIR:?}/${cleaned_dir}"
  fi
done
```
--------------------------------------------------------------------------------
/.trunk/trunk.yaml:
--------------------------------------------------------------------------------
```yaml
# This file controls the behavior of Trunk: https://docs.trunk.io/cli
# To learn more about the format of this file, see https://docs.trunk.io/reference/trunk-yaml
version: 0.1
cli:
version: 1.24.0
# Trunk provides extensibility via plugins. (https://docs.trunk.io/plugins)
plugins:
sources:
- id: trunk
ref: v1.7.0
uri: https://github.com/trunk-io/plugins
# Many linters and tools depend on runtimes - configure them here. (https://docs.trunk.io/runtimes)
runtimes:
enabled:
- [email protected]
- [email protected]
- [email protected]
# This is the section where you manage your linters. (https://docs.trunk.io/check/configuration)
lint:
disabled:
- black
enabled:
- [email protected]
- [email protected]
- [email protected]
- [email protected]
- [email protected]
- [email protected]
- [email protected]
- [email protected]
- [email protected]
- [email protected]
- [email protected]
- [email protected]
- [email protected]
- git-diff-check
- [email protected]
- [email protected]
- [email protected]
actions:
enabled:
- trunk-announce
- trunk-check-pre-push
- trunk-fmt-pre-commit
- trunk-upgrade-available
```
--------------------------------------------------------------------------------
/dev/setup.sh:
--------------------------------------------------------------------------------
```bash
#!/bin/bash
# Set up the project environment with uv.
#
# Usage: setup.sh [--use-venv] [--deps production|development]
# (Previously this script had no shebang and no error handling; fail fast
# like the other dev/ scripts.)
set -Eeo pipefail

# Constants
SCRIPT_FILE="$(readlink -f "$0")"
SCRIPT_DIR="$(dirname "${SCRIPT_FILE}")"
MODULE_DIR="$(dirname "${SCRIPT_DIR}")"

# Arguments
deps="production"
use_venv=false
while (($# > 0)); do
  if [[ $1 == "--use-venv" ]]; then
    use_venv=true
    shift 1
  elif [[ $1 == "--deps" ]]; then
    if [[ $2 != "production" && $2 != "development" ]]; then
      echo "Error: deps must be one of 'production' or 'development'"
      exit 1
    fi
    deps="$2"
    shift 2
  else
    echo "Unknown argument: $1"
    exit 1
  fi
done

# Change to the module directory
cd "${MODULE_DIR}"

# Install uv and dependencies
pip install --force-reinstall -r "${MODULE_DIR}/requirements.setup.txt"

UV_PIP_OPTIONS=("--force-reinstall")
if [[ ${use_venv} == true ]]; then
  # Create and activate a virtual environment
  uv venv
  if [[ -f .venv/bin/activate ]]; then
    # shellcheck disable=SC1091
    source .venv/bin/activate
  else
    echo "Error: .venv/bin/activate not found"
    exit 1
  fi
else
  # Without a venv, install into the system interpreter.
  UV_PIP_OPTIONS+=("--system")
fi

# Install the package and its dependencies (dev/test extras in development mode)
if [[ ${deps} == "production" ]]; then
  uv pip install "${UV_PIP_OPTIONS[@]}" -e "."
else
  uv pip install "${UV_PIP_OPTIONS[@]}" -e ".[dev,test]"
fi
```
--------------------------------------------------------------------------------
/src/mcp_vertexai_search/utils.py:
--------------------------------------------------------------------------------
```python
from typing import Dict, List
from mcp import types as mcp_types
from mcp_vertexai_search.config import DataStoreConfig
def to_mcp_tool(tool_name: str, description: str) -> mcp_types.Tool:
    """Build an MCP Tool that exposes a single required `query` string argument."""
    query_schema = {
        "type": "string",
        "description": """\
A natural language question, not search keywords, used to query the documents.
The query question should be sentence(s), not search keywords.
""".strip(),
    }
    input_schema = {
        "type": "object",
        "required": ["query"],
        "properties": {"query": query_schema},
    }
    return mcp_types.Tool(
        name=tool_name,
        description=description,
        inputSchema=input_schema,
    )
def to_mcp_tools_map(
    data_store_configs: List[DataStoreConfig],
) -> Dict[str, mcp_types.Tool]:
    """Map each data store's tool name to its corresponding MCP Tool."""
    tools: Dict[str, mcp_types.Tool] = {}
    for config in data_store_configs:
        tools[config.tool_name] = to_mcp_tool(config.tool_name, config.description)
    return tools
```
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
```toml
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["src/mcp_vertexai_search"]
[project]
name = "mcp-vertexai-search"
version = "0.1.0"
authors = [{ name = "ubie" }]
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10.0"
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Information Technology",
"Intended Audience :: System Administrators",
"Operating System :: OS Independent",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Typing :: Typed",
]
description = "An MCP server to search documents using Vertex AI"
dependencies = [
"click>=8.1.8",
"google-cloud-aiplatform>=1.96.0",
"google-cloud-discoveryengine>=0.13.8",
"loguru>=0.7.3",
"mcp[cli]>=1.9.2",
"pydantic>=2.10.6",
"pyyaml>=6.0.2",
"uvicorn>=0.34.0",
"vertexai>=1.43.0",
]
[project.optional-dependencies]
dev = [
"autopep8>=2.3.2",
"bandit>=1.8.3",
"black>=25.1.0",
"google-genai>=1.2.0",
"isort>=6.0.0",
"langgraph>=0.2.74",
"pytest>=8.3.4",
"ruff>=0.9.6",
]
[project.scripts]
mcp-vertexai-search = "mcp_vertexai_search.cli:cli"
```
--------------------------------------------------------------------------------
/src/research_agent/mcp_client.py:
--------------------------------------------------------------------------------
```python
from contextlib import AsyncExitStack
from typing import Optional
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
class MCPClient:
    """A thin client for an MCP server reachable over SSE transport."""

    def __init__(self, name: str, server_url: Optional[str] = None):
        # Initialize session and client objects
        self.name = name
        self.session: Optional[ClientSession] = None
        self.exit_stack = AsyncExitStack()
        # BUG FIX: `connect_to_server` is async and cannot be run from a
        # synchronous __init__ — the old call only created a never-awaited
        # coroutine, so no connection ever happened. Remember the URL and
        # require callers to `await client.connect_to_server()` explicitly.
        self.server_url = server_url

    async def connect_to_server(self, server_url: Optional[str] = None):
        """Connect to an MCP server running with SSE transport.

        Args:
            server_url: SSE endpoint URL; falls back to the URL given to
                the constructor.

        Raises:
            ValueError: If no URL was provided here or at construction.
        """
        url = server_url if server_url is not None else self.server_url
        if url is None:
            raise ValueError("server_url is not set")
        # Use AsyncExitStack to manage the stream and session contexts
        streams = await self.exit_stack.enter_async_context(sse_client(url=url))
        self.session = await self.exit_stack.enter_async_context(
            ClientSession(*streams)
        )
        # Initialize the MCP session (handshake)
        await self.session.initialize()

    async def cleanup(self):
        """Properly clean up the session and streams."""
        await self.exit_stack.aclose()

    async def list_tools(self):
        """List the tools exposed by the connected server."""
        return await self.session.list_tools()

    async def call_tool(self, tool_name: str, tool_arguments: Optional[dict] = None):
        """Invoke a server tool by name with optional arguments."""
        return await self.session.call_tool(tool_name, tool_arguments)
if __name__ == "__main__":
    import asyncio

    async def main():
        # BUG FIX: `name` is a required positional argument of MCPClient;
        # the old `MCPClient()` call raised a TypeError.
        client = MCPClient("research-agent")
        await client.connect_to_server(server_url="http://0.0.0.0:8080/sse")
        try:
            tools = await client.list_tools()
            print(tools)
            tool_call = await client.call_tool(
                "document-search", {"query": "cpp segment とはなんですか?"}
            )
            print(tool_call)
        finally:
            # Ensure cleanup runs even if a call above fails.
            await client.cleanup()

    asyncio.run(main())
```
--------------------------------------------------------------------------------
/src/mcp_vertexai_search/google_cloud.py:
--------------------------------------------------------------------------------
```python
from typing import List, Optional
from google import auth
from google.auth import impersonated_credentials
def get_credentials(
    project_id: Optional[str] = None,
    impersonate_service_account: Optional[str] = None,
    scopes: Optional[List[str]] = None,
    lifetime: Optional[int] = None,
) -> auth.credentials.Credentials:
    """Return Google Cloud credentials.

    Uses service-account impersonation when a target account is given,
    otherwise falls back to the application default credentials.
    """
    if impersonate_service_account is None:
        return get_default_credentials(project_id)
    return get_impersonate_credentials(
        impersonate_service_account, project_id, scopes, lifetime
    )
def get_default_credentials(
    project_id: Optional[str] = None,
) -> auth.credentials.Credentials:
    """Resolve Application Default Credentials, optionally pinned to a quota project."""
    kwargs = {} if project_id is None else {"quota_project_id": project_id}
    credentials, _ = auth.default(**kwargs)
    return credentials
def get_impersonate_credentials(
    impersonate_service_account: str,
    quoted_project_id: Optional[str] = None,
    scopes: Optional[List[str]] = None,
    lifetime: Optional[int] = None,
) -> impersonated_credentials.Credentials:
    """Build short-lived credentials that impersonate a service account.

    Args:
        impersonate_service_account: Email of the service account to impersonate.
        quoted_project_id: Optional quota project used when resolving the
            source (ADC) credentials.
        scopes: OAuth scopes for the impersonated token; defaults to the
            cloud-platform scope.
        lifetime: Token lifetime in seconds; defaults to 3600.

    Returns:
        Impersonated credentials backed by Application Default Credentials.
    """
    if scopes is None:
        scopes = ["https://www.googleapis.com/auth/cloud-platform"]
    if lifetime is None:
        # NOTE The maximum lifetime is 3600s. If we can't load a table within
        # 1 hour, we have to consider an alternative way.
        lifetime = 3600
    # Resolve ADC exactly once. The original called auth.default() twice when
    # a quota project was supplied, and overwrote quoted_project_id with the
    # project id returned by auth.default().
    if quoted_project_id is not None:
        source_credentials, _ = auth.default(quota_project_id=quoted_project_id)
    else:
        source_credentials, _ = auth.default()
    return impersonated_credentials.Credentials(
        source_credentials=source_credentials,
        target_principal=impersonate_service_account,
        target_scopes=scopes,
        lifetime=lifetime,
    )
```
--------------------------------------------------------------------------------
/src/mcp_vertexai_search/config.py:
--------------------------------------------------------------------------------
```python
from typing import List, Optional
import yaml
from pydantic import BaseModel, Field
class GenerateContentConfig(BaseModel):
    """The configuration for the generate content API."""

    # Sampling temperature forwarded to the generate-content call.
    temperature: float = Field(
        description="The temperature for the generate content API",
        default=0.7,
    )
    # Nucleus-sampling (top-p) cutoff forwarded to the generate-content call.
    top_p: float = Field(
        description="The top p for the generate content API",
        default=0.95,
    )
class VertexAIModelConfig(BaseModel):
    """The configuration for a Vertex AI model."""

    model_name: str = Field(..., description="The name of the Vertex AI model")
    project_id: str = Field(..., description="The project ID of the Vertex AI model")
    location: str = Field(..., description="The location of the model")
    # When set, credentials are obtained via service-account impersonation
    # instead of plain Application Default Credentials.
    impersonate_service_account: Optional[str] = Field(
        None, description="The service account to impersonate"
    )
    # Defaults to a fresh GenerateContentConfig (temperature/top_p defaults).
    generate_content_config: Optional[GenerateContentConfig] = Field(
        description="The configuration for the generate content API",
        default_factory=GenerateContentConfig,
    )
class DataStoreConfig(BaseModel):
    """The configuration for a Vertex AI data store."""

    project_id: str = Field(
        ..., description="The project ID of the Vertex AI data store"
    )
    location: str = Field(..., description="The location of the Vertex AI data store")
    datastore_id: str = Field(..., description="The ID of the Vertex AI data store")
    # NOTE(review): the description mentions a computed default, but the field
    # is declared required (...) and no default is computed anywhere visible —
    # confirm whether the computed-name behavior was dropped.
    tool_name: str = Field(
        ...,
        description="The name of the tool. If not provided, defaults to 'search_document_<datastore_id>'",
    )
    description: str = Field(
        description="The description of the Vertex AI data store",
        default="",
    )
class MCPServerConfig(BaseModel):
    """The configuration for an MCP server."""

    # Server name announced to MCP clients; defaults to "document-search".
    name: str = Field(
        description="The name of the MCP server", default="document-search"
    )
class Config(BaseModel):
    """The configuration for the application."""

    server: MCPServerConfig = Field(
        description="The server configuration", default_factory=MCPServerConfig
    )
    # NOTE(review): VertexAIModelConfig has required fields, so this
    # default_factory raises a ValidationError when "model" is omitted —
    # effectively the field is required. Confirm whether that is intended.
    model: VertexAIModelConfig = Field(
        description="The model configuration", default_factory=VertexAIModelConfig
    )
    # One entry per Vertex AI Search data store exposed as a tool.
    data_stores: List[DataStoreConfig] = Field(
        description="The data stores configuration", default_factory=list
    )
def load_yaml_config(file_path: str) -> Config:
    """Load and validate a YAML config file into a Config model.

    Args:
        file_path: Path to the YAML configuration file.

    Returns:
        The parsed and validated Config.

    Raises:
        OSError: If the file cannot be opened.
        pydantic.ValidationError: If the YAML content does not match the schema.
    """
    # Pin the encoding so config files parse identically on every platform
    # (the original relied on the locale-dependent default encoding).
    with open(file_path, "r", encoding="utf-8") as f:
        return Config(**yaml.safe_load(f))
```
--------------------------------------------------------------------------------
/src/mcp_vertexai_search/cli.py:
--------------------------------------------------------------------------------
```python
import asyncio
import click
import vertexai
from mcp_vertexai_search.agent import (
VertexAISearchAgent,
create_model,
create_vertex_ai_tools,
get_default_safety_settings,
get_generation_config,
get_system_instruction,
)
from mcp_vertexai_search.config import load_yaml_config
from mcp_vertexai_search.google_cloud import get_credentials
from mcp_vertexai_search.server import create_server, run_sse_server, run_stdio_server
# Root command group; subcommands attach themselves via @cli.command(...).
cli = click.Group()
@cli.command("serve")
# trunk-ignore(bandit/B104)
@click.option("--host", type=str, default="0.0.0.0", help="The host to listen on")
@click.option("--port", type=int, default=8080, help="The port to listen on")
@click.option(
    "--transport",
    type=click.Choice(["stdio", "sse"]),
    default="stdio",
    help="The transport to use",
)
@click.option("--config", type=click.Path(exists=True), help="The config file")
def serve(
    host: str,
    port: int,
    transport: str,
    config: str,
):
    """Run the MCP document-search server over stdio or SSE."""
    server_config = load_yaml_config(config)
    # Honor impersonate_service_account the same way the `search` command
    # does; the original serve() ignored it and always used ADC.
    credentials = get_credentials(
        impersonate_service_account=server_config.model.impersonate_service_account,
    )
    vertexai.init(
        project=server_config.model.project_id,
        location=server_config.model.location,
        credentials=credentials,
    )
    search_tools = create_vertex_ai_tools(server_config.data_stores)
    model = create_model(
        model_name=server_config.model.model_name,
        tools=search_tools,
        system_instruction=get_system_instruction(),
    )
    agent = VertexAISearchAgent(model=model)
    app = create_server(agent, server_config)
    # run_stdio_server/run_sse_server are synchronous (they drive their own
    # event loop internally); wrapping them in asyncio.run(), as the original
    # did, raised "a coroutine was expected" after the server returned.
    if transport == "stdio":
        run_stdio_server(app)
    elif transport == "sse":
        run_sse_server(app, host, port)
    else:
        raise ValueError(f"Invalid transport: {transport}")
@cli.command("search")
@click.option("--config", type=click.Path(exists=True), help="The config file")
@click.option("--query", type=str, help="The query to search for")
def search(
    config: str,
    query: str,
):
    """Run a single grounded search from the command line and print the result."""
    # Load the config and initialize Vertex AI with (possibly impersonated)
    # credentials.
    cfg = load_yaml_config(config)
    creds = get_credentials(
        impersonate_service_account=cfg.model.impersonate_service_account,
    )
    vertexai.init(
        project=cfg.model.project_id,
        location=cfg.model.location,
        credentials=creds,
    )
    # Build the grounded model and wrap it in the search agent.
    tools = create_vertex_ai_tools(cfg.data_stores)
    grounded_model = create_model(
        model_name=cfg.model.model_name,
        tools=tools,
        system_instruction=get_system_instruction(),
    )
    agent = VertexAISearchAgent(model=grounded_model)
    # Generate and print the response.
    response = agent.search(
        query,
        generation_config=get_generation_config(),
        safety_settings=get_default_safety_settings(),
    )
    print(response)
@cli.command("validate-config")
@click.option("--config", type=click.Path(exists=True), help="The config file")
@click.option("--verbose", type=bool, default=False, help="Verbose output")
def validate_config(config: str, verbose: bool):
    """Validate the config file, raising ValueError when it cannot be parsed."""
    # Keep the try narrow: only load/validation failures should be reported as
    # an invalid config. The original also wrapped the verbose print, so an
    # error while printing a *valid* config was misreported as "Invalid config".
    # pylint: disable=broad-exception-caught
    try:
        server_config = load_yaml_config(config)
    except Exception as e:
        raise ValueError(f"Invalid config: {e}") from e
    if verbose:
        print(server_config.model_dump_json(indent=2))
```
--------------------------------------------------------------------------------
/src/mcp_vertexai_search/server.py:
--------------------------------------------------------------------------------
```python
import anyio
import mcp.types as types
from mcp.server.lowlevel import Server
from mcp.shared.exceptions import ErrorData, McpError
from mcp_vertexai_search.agent import (
VertexAISearchAgent,
get_default_safety_settings,
get_generation_config,
)
from mcp_vertexai_search.config import Config
from mcp_vertexai_search.utils import to_mcp_tools_map
def create_server(
    agent: VertexAISearchAgent,
    config: Config,
) -> Server:
    """Create the MCP server exposing one search tool per configured data store.

    Args:
        agent: The Vertex AI search agent that answers queries.
        config: Application config providing data stores and generation settings.

    Returns:
        A configured low-level MCP Server instance.
    """
    app = Server("document-search")
    # Map tool name -> MCP tool definition, used for both dispatch and listing.
    tools_map = to_mcp_tools_map(config.data_stores)

    # TODO Add @app.list_prompts()

    @app.call_tool()
    async def call_tool(
        name: str, arguments: dict
    ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
        # Caller mistakes (unknown tool, missing argument) are INVALID_PARAMS.
        if name not in tools_map:
            raise McpError(
                ErrorData(code=types.INVALID_PARAMS, message=f"Unknown tool: {name}")
            )
        if "query" not in arguments:
            raise McpError(
                ErrorData(code=types.INVALID_PARAMS, message="query is required")
            )
        # pylint: disable=broad-exception-caught
        try:
            # TODO handle retry logic
            generation_config = get_generation_config(
                temperature=config.model.generate_content_config.temperature,
                top_p=config.model.generate_content_config.top_p,
            )
            safety_settings = get_default_safety_settings()
            response = agent.search(
                query=arguments["query"],
                generation_config=generation_config,
                safety_settings=safety_settings,
            )
            return [types.TextContent(type="text", text=response)]
        # pylint: disable=broad-exception-caught
        except Exception as e:
            # Server-side failures are internal errors, not caller mistakes;
            # the original misreported them as INVALID_PARAMS.
            raise McpError(
                ErrorData(code=types.INTERNAL_ERROR, message=str(e))
            ) from e

    @app.list_tools()
    async def list_tools() -> list[types.Tool]:
        return list(tools_map.values())

    return app
def run_stdio_server(app: Server) -> None:
    """Run the server using the stdio transport."""
    try:
        from mcp.server.stdio import stdio_server
    except ImportError as e:
        raise ImportError("stdio transport is not available") from e

    async def arun():
        # stdio_server yields a (read, write) stream pair for app.run.
        async with stdio_server() as (read_stream, write_stream):
            await app.run(
                read_stream, write_stream, app.create_initialization_options()
            )

    anyio.run(arun)
def run_sse_server(app: Server, host: str, port: int) -> None:
    """Run the server using the SSE transport."""
    try:
        import uvicorn
        from mcp.server.sse import SseServerTransport
        from starlette.applications import Starlette
        from starlette.routing import Mount, Route
    except ImportError as e:
        raise ImportError("SSE transport is not available") from e

    # Client-to-server POST messages come back through the /messages/ mount.
    transport = SseServerTransport("/messages/")

    async def handle_sse(request):
        # NOTE(review): request._send is a private Starlette attribute —
        # confirm it is still the sanctioned way to hand the ASGI send
        # callable to the MCP SSE transport across Starlette upgrades.
        async with transport.connect_sse(
            request.scope, request.receive, request._send
        ) as (read_stream, write_stream):
            await app.run(
                read_stream, write_stream, app.create_initialization_options()
            )

    # Create and serve the Starlette app.
    starlette_app = Starlette(
        debug=True,
        routes=[
            Route("/sse", endpoint=handle_sse),
            Mount("/messages/", app=transport.handle_post_message),
        ],
    )
    uvicorn.run(starlette_app, host=host, port=port)
```
--------------------------------------------------------------------------------
/src/research_agent/chat.py:
--------------------------------------------------------------------------------
```python
import argparse
import asyncio
import json
import textwrap
from typing import List
from google import genai
from google.genai import chats, types
from loguru import logger
from pydantic import BaseModel, Field
from research_agent.mcp_client import MCPClient
from research_agent.utils import to_gemini_tool
class Reference(BaseModel):
    """A reference to a document."""

    title: str = Field(..., description="The title of the document.")
    raw_text: str = Field(..., description="The raw text of the document.")
class SearchResponse(BaseModel):
    """The response from the search tool."""

    answer: str = Field(..., description="The answer to the user's question.")
    references: List[Reference] = Field(
        ...,
        description="The references to the documents that are used to answer the user's question.",
    )

    @classmethod
    def from_json_string(cls, json_string: str) -> "SearchResponse":
        """Deserialize the search response from a JSON string.

        Raises:
            json.JSONDecodeError: If the string is not valid JSON.
            pydantic.ValidationError: If the JSON does not match this schema.
        """
        return cls(**json.loads(json_string))

    def __str__(self) -> str:
        """Render the answer followed by a bulleted reference list."""
        # The quoted "\n".join inside the f-string expression requires
        # Python 3.12 (PEP 701); the project pins 3.12 in .python-version.
        return textwrap.dedent(f"""
        Answer: {self.answer}
        References:
        {"\n".join([f" - {ref.title}: {ref.raw_text}" for ref in self.references])}
        """)
async def process_query(
    chat_client: chats.Chat,
    mcp_client: MCPClient,
    query: str,
) -> str:
    """Process the user query using Gemini and MCP tools.

    Sends the query to the chat model, executes any function calls it requests
    through the MCP client, and joins every collected text piece with newlines.

    Raises:
        RuntimeError: If Gemini returns no candidates, a tool call yields no
            content, or a response part is neither text nor a function call.
    """
    response = chat_client.send_message(message=[query])
    if not response.candidates:
        raise RuntimeError("No response from Gemini")
    # Accumulates plain-text pieces and rendered tool results in order.
    response_text = []
    for candidate in response.candidates:
        if not candidate.content:
            logger.debug(f"No content in candidate {candidate}")
            continue
        for part in candidate.content.parts:
            if part.text:
                response_text.append(part.text)
            elif part.function_call:
                # The model asked for a tool call: forward it over MCP.
                tool_name = part.function_call.name
                tool_args = part.function_call.args
                logger.debug(f"Tool name: {tool_name}, tool args: {tool_args}")
                tool_call = await mcp_client.call_tool(tool_name, tool_args)
                if tool_call and tool_call.content:
                    for content in tool_call.content:
                        text = content.text
                        if not text:
                            logger.info(f"No text in tool call content {content}")
                            continue
                        try:
                            # Prefer the structured rendering; fall back to the
                            # raw text when the payload isn't a SearchResponse.
                            parsed_content = SearchResponse.from_json_string(text)
                            response_text.append(str(parsed_content))
                        except Exception as e:  # pylint: disable=broad-except
                            logger.error(
                                f"Failed to deserialize tool call content {content}: {e}"
                            )
                            response_text.append(text)
                else:
                    raise RuntimeError(f"No tool call content {tool_call}")
            else:
                raise RuntimeError(f"Unknown part type {part}")
    return "\n".join(response_text)
async def chat(server_url: str):
    """
    Run the chat server.

    Connects to the MCP server, mirrors its tools into a Gemini chat session,
    and loops on user input until the user enters 'bye'.

    Args:
        server_url: SSE endpoint of the MCP server (e.g. http://host:port/sse).
    """
    # Why do we use google-genai, not vertexai?
    # Because it is easier to convert MCP tools to GenAI tools in google-genai.
    genai_client = genai.Client(vertexai=True, location="us-central1")
    mcp_client = MCPClient(name="document-search")
    await mcp_client.connect_to_server(server_url=server_url)
    try:
        # Collect tools from MCP server and convert them to GenAI tools.
        mcp_tools = await mcp_client.list_tools()
        genai_tools = [to_gemini_tool(tool) for tool in mcp_tools.tools]
        # Create chat client
        chat_client = genai_client.chats.create(
            model="gemini-2.0-flash",
            config=types.GenerateContentConfig(
                tools=genai_tools,
                system_instruction="""
        You are a helpful assistant to search documents.
        You have to pass the query to the tool to search the documents as much natural as possible.
        """,
            ),
        )
        print("If you want to quit, please enter 'bye'")
        while True:
            # Get user query
            query = input("Enter your query: ")
            if query == "bye":
                break
            # Get response from GenAI
            response = await process_query(chat_client, mcp_client, query)
            print(response)
    # pylint: disable=broad-except
    except Exception as e:
        raise RuntimeError from e
    finally:
        # The original only cleaned up on the error path, leaking the MCP
        # session on a normal 'bye' exit; always release it.
        await mcp_client.cleanup()
if __name__ == "__main__":
    # Parse command line arguments.
    parser = argparse.ArgumentParser()
    # trunk-ignore(bandit/B104)
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8080)
    cli_args = parser.parse_args()
    # Launch the interactive chat loop against the server's SSE endpoint.
    asyncio.run(chat(f"http://{cli_args.host}:{cli_args.port}/sse"))
```
--------------------------------------------------------------------------------
/tests/test_config.py:
--------------------------------------------------------------------------------
```python
import unittest
from src.mcp_vertexai_search.config import (
Config,
DataStoreConfig,
GenerateContentConfig,
MCPServerConfig,
VertexAIModelConfig,
)
class TestConfig(unittest.TestCase):
    """Unit tests for the pydantic config models in mcp_vertexai_search.config."""

    def test_default_config(self):
        """Test that default Config values are set correctly."""
        # "model" must be supplied here: VertexAIModelConfig has required
        # fields, so Config's default_factory for it cannot succeed alone.
        config = Config(
            model=VertexAIModelConfig(
                project_id="test-project",
                model_name="test-model",
                location="test-location",
            ),
        )
        self.assertIsInstance(config.server, MCPServerConfig)
        self.assertIsInstance(config.model, VertexAIModelConfig)
        self.assertEqual(config.data_stores, [])

    def test_custom_config(self):
        """Test that Config can be initialized with custom values."""
        custom_server = MCPServerConfig(name="test-server")
        custom_model = VertexAIModelConfig(
            project_id="test-project",
            model_name="test-model",
            location="test-location",
        )
        custom_data_store = DataStoreConfig(
            project_id="test-project",
            location="test-location",
            datastore_id="test-datastore",
            tool_name="test-tool",
        )
        config = Config(
            server=custom_server,
            model=custom_model,
            data_stores=[custom_data_store]
        )
        self.assertEqual(config.server.name, "test-server")
        self.assertEqual(config.model.model_name, "test-model")
        self.assertEqual(config.model.location, "test-location")
        self.assertEqual(len(config.data_stores), 1)
        self.assertEqual(config.data_stores[0].datastore_id, "test-datastore")

    def test_default_mcpserverconfig(self):
        """Test MCPServerConfig default values."""
        server_config = MCPServerConfig()
        self.assertEqual(server_config.name, "document-search")

    def test_custom_mcpserverconfig(self):
        """Test MCPServerConfig with custom values."""
        server_config = MCPServerConfig(name="custom-server")
        self.assertEqual(server_config.name, "custom-server")

    def test_default_vertexaimodelconfig(self):
        """Test VertexAIModelConfig default values."""
        model_config = VertexAIModelConfig(
            project_id="test-project",
            location="test-location",
            model_name="test-model",
        )
        # generate_content_config is the only defaulted field being exercised.
        self.assertIsInstance(model_config.generate_content_config, GenerateContentConfig)
        self.assertEqual(model_config.project_id, "test-project")
        self.assertEqual(model_config.location, "test-location")
        self.assertEqual(model_config.model_name, "test-model")
        self.assertEqual(model_config.generate_content_config.temperature, 0.7)
        self.assertEqual(model_config.generate_content_config.top_p, 0.95)

    def test_custom_vertexaimodelconfig(self):
        """Test VertexAIModelConfig with custom values."""
        custom_gen_config = GenerateContentConfig(temperature=0.8, top_p=0.9)
        model_config = VertexAIModelConfig(
            model_name="custom-model",
            location="custom-location",
            project_id="custom-project",
            generate_content_config=custom_gen_config,
        )
        self.assertEqual(model_config.model_name, "custom-model")
        self.assertEqual(model_config.location, "custom-location")
        self.assertEqual(model_config.project_id, "custom-project")
        self.assertEqual(model_config.generate_content_config.temperature, 0.8)
        self.assertEqual(model_config.generate_content_config.top_p, 0.9)

    def test_default_generatecontentconfig(self):
        """Test GenerateContentConfig default values."""
        gen_config = GenerateContentConfig()
        self.assertEqual(gen_config.temperature, 0.7)
        self.assertEqual(gen_config.top_p, 0.95)

    def test_custom_generatecontentconfig(self):
        """Test GenerateContentConfig with custom values."""
        gen_config = GenerateContentConfig(temperature=0.6, top_p=0.8)
        self.assertEqual(gen_config.temperature, 0.6)
        self.assertEqual(gen_config.top_p, 0.8)

    def test_default_datastoreconfig(self):
        """Test DataStoreConfig default values."""
        # Only "description" actually has a default; tool_name is required.
        datastore_config = DataStoreConfig(
            project_id="test-project",
            location="test-location",
            datastore_id="test-datastore",
            tool_name="test-tool",
        )
        self.assertEqual(datastore_config.description, "")
        self.assertEqual(datastore_config.tool_name, "test-tool")

    def test_custom_datastoreconfig(self):
        """Test DataStoreConfig with custom values."""
        datastore_config = DataStoreConfig(
            project_id="custom-project",
            location="custom-location",
            datastore_id="custom-datastore",
            description="custom-description",
            tool_name="custom-tool",
        )
        self.assertEqual(datastore_config.project_id, "custom-project")
        self.assertEqual(datastore_config.location, "custom-location")
        self.assertEqual(datastore_config.datastore_id, "custom-datastore")
        self.assertEqual(datastore_config.description, "custom-description")
        self.assertEqual(datastore_config.tool_name, "custom-tool")

    def test_computed_tool_name_datastoreconfig(self):
        """Test DataStoreConfig computed tool name when not provided."""
        # NOTE(review): despite the name, this test passes tool_name
        # explicitly; DataStoreConfig declares tool_name as required, so no
        # computed default exists — confirm whether that behavior was dropped.
        datastore_config = DataStoreConfig(
            project_id="custom-project",
            location="custom-location",
            datastore_id="custom-datastore",
            description="custom-description",
            tool_name = "document-search"
        )
        expected_tool_name = "document-search"
        self.assertEqual(datastore_config.tool_name, expected_tool_name)
```
--------------------------------------------------------------------------------
/src/mcp_vertexai_search/agent.py:
--------------------------------------------------------------------------------
```python
import textwrap
from typing import List, Optional
from vertexai import generative_models
from mcp_vertexai_search.config import DataStoreConfig
# class Reference(BaseModel):
# """Reference"""
# title: str = Field(..., description="Title of the reference snippet")
# raw_text: str = Field(..., description="Content of the reference raw text")
# class SearchResponse(BaseModel):
# """Search response"""
# answer: str = Field(..., description="The answer to the query")
# references: List[Reference] = Field(
# ..., description="References used to generate the answer"
# )
def get_generation_config(
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
) -> generative_models.GenerationConfig:
    """Build the generation config used by the search agent.

    TODO: We should customize this based on the use case.
    """
    # JSON output is requested so responses can be parsed downstream.
    config = generative_models.GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        response_mime_type="application/json",
    )
    return config
def get_default_safety_settings() -> List[generative_models.SafetySetting]:
    """Default safety settings: block only high-severity content per category.

    TODO: We should customize this based on the use case.
    """
    # One setting per harm category, all at the same threshold.
    categories = [
        generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT,
    ]
    return [
        generative_models.SafetySetting(
            category=category,
            threshold=generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
        )
        for category in categories
    ]
def create_model(
    model_name: str,
    tools: List[generative_models.Tool],
    system_instruction: str,
) -> generative_models.GenerativeModel:
    """Construct a generative model wired with the given tools and system prompt."""
    # GenerativeModel accepts a list of instruction parts.
    instructions = [system_instruction]
    return generative_models.GenerativeModel(
        model_name=model_name,
        tools=tools,
        system_instruction=instructions,
    )
def create_vertexai_search_tool(
    project_id: str,
    location: str,
    datastore_id: str,
) -> generative_models.Tool:
    """Create a retrieval tool grounded on one Vertex AI Search data store."""
    datastore_source = generative_models.grounding.VertexAISearch(
        project=project_id,
        location=location,
        datastore=datastore_id,
    )
    retrieval = generative_models.grounding.Retrieval(source=datastore_source)
    return generative_models.Tool.from_retrieval(retrieval=retrieval)
def create_vertex_ai_tools(
    data_stores: List[DataStoreConfig],
) -> List[generative_models.Tool]:
    """Create one Vertex AI search tool per configured data store."""
    return [
        create_vertexai_search_tool(
            store.project_id,
            store.location,
            store.datastore_id,
        )
        for store in data_stores
    ]
def get_system_instruction() -> str:
    """Return the system prompt that constrains answers to grounded snippets
    and enforces the JSON answer/references response schema."""
    return textwrap.dedent(
        """
        You are a helpful assistant knowledgeable about Alphabet quarterly earning reports.
        Help users with their queries related to Alphabet by only responding with information available in the Grounding Knowledge store.

        Respond in the same language as the user's query.
        For instance, if the user's query is in Japanese, your response should be in Japanese.

        - Always refer to the tool and ground your answers in it.
        - Understand the retrieved snippet by the tool and only use that information to help users.
        - For supporting references, you can provide the Grounding tool snippets verbatim, and any other info like page number.
        - If information is not available in the tool, mention you don't have access to the information and do not try to make up an answer.
        - Leave "references" as an empty list if you are unsure about the page and text snippet or if no relevant snippet is found.
        - Output "answer" should be "I don't know" when the user question is irrelevant or outside the scope of the knowledge base.

        The Grounding tool finds the most relevant snippets from the Alphabet earning reports data store.
        Use the information provided by the tool as your knowledge base.
        - ONLY use information available from the Grounding tool.
        - DO NOT make up information or invent details not present in the retrieved snippets.

        Response should ALWAYS be in the following JSON format:
        ## JSON schema
        {
            "answer": {
                "type": "string",
                "description": "The answer to the user's query"
            },
            "references": [
                {
                    "title": {
                        "type": "string",
                        "description": "The title of the reference"
                    },
                    "raw_text": {
                        "type": "string",
                        "description": "The raw text in the reference"
                    }
                }
            ]
        }
        """
    ).strip()
class VertexAISearchAgent:
    """Thin wrapper around a grounded Gemini model for document search."""

    def __init__(
        self,
        model: generative_models.GenerativeModel,
    ):
        # pylint: disable=line-too-long
        # The model is expected to be created with grounding tools attached
        # (see create_model / create_vertex_ai_tools).
        self.model = model

    async def asearch(
        self,
        query: str,
        generation_config: generative_models.GenerationConfig,
        safety_settings: Optional[List[generative_models.SafetySetting]],
    ) -> str:
        """Asynchronous search returning the model's full text response."""
        # stream must be False: with stream=True, generate_content_async
        # returns an async iterator of chunks that has no .text attribute,
        # so the original `return response.text` could not work. This also
        # matches the synchronous search() below.
        response = await self.model.generate_content_async(
            contents=[query],
            generation_config=generation_config,
            safety_settings=safety_settings,
            stream=False,
        )
        return response.text

    def search(
        self,
        query: str,
        generation_config: generative_models.GenerationConfig,
        safety_settings: Optional[List[generative_models.SafetySetting]],
    ) -> str:
        """Synchronous search returning the model's full text response."""
        # TODO Enable to customize generation config and safety settings
        response = self.model.generate_content(
            contents=[query],
            generation_config=generation_config,
            safety_settings=safety_settings,
            stream=False,
        )
        return response.text
```