# Directory Structure ``` ├── .github │ ├── CODEOWNERS │ ├── pull_request_template.md │ └── workflows │ └── release.yml ├── .gitignore ├── extend_ai_toolkit │ ├── __init__.py │ ├── __version__.py │ ├── crewai │ │ ├── __init__.py │ │ ├── extend_tool.py │ │ └── toolkit.py │ ├── examples │ │ ├── crewai-agent.py │ │ ├── langchain-react-agent.py │ │ └── openai-agent.py │ ├── langchain │ │ ├── __init__.py │ │ ├── extend_tool.py │ │ └── toolkit.py │ ├── modelcontextprotocol │ │ ├── __init__.py │ │ ├── client │ │ │ ├── __init__.py │ │ │ ├── anthropic_chat_client.py │ │ │ ├── chat_client.py │ │ │ ├── mcp_client.py │ │ │ └── openai_chat_client.py │ │ ├── main_sse.py │ │ ├── main.py │ │ ├── options.py │ │ └── server.py │ ├── notebooks │ │ └── langchain-react-agent.ipynb │ ├── openai │ │ ├── __init__.py │ │ ├── extend_tool.py │ │ └── toolkit.py │ ├── shared │ │ ├── __init__.py │ │ ├── agent_toolkit.py │ │ ├── api.py │ │ ├── configuration.py │ │ ├── enums.py │ │ ├── functions.py │ │ ├── helpers.py │ │ ├── interfaces.py │ │ ├── models.py │ │ ├── prompts.py │ │ ├── schemas.py │ │ ├── tools.py │ │ └── utils.py │ └── tests │ ├── __init__.py │ ├── test_configuration.py │ ├── test_crewai_toolkit.py │ ├── test_integration.py │ ├── test_langchain_toolkit.py │ ├── test_mcp_server.py │ ├── test_openai_toolkit.py │ ├── test_options.py │ └── test_validate_tool_spec.py ├── LICENSE ├── Makefile ├── pyproject.toml └── README.md ``` # Files -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- ``` # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class uv.lock # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before 
PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control #poetry.lock # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. #pdm.lock # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # in version control. # https://pdm.fming.dev/latest/usage/project/#working-with-version-control .pdm.toml .pdm-python .pdm-build/ # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv* ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ .DS_store .idea ``` -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- ```markdown # Extend AI Toolkit ## Overview The [Extend](https://www.paywithextend.com) AI Toolkit provides a python based implementation of tools to integrate with Extend APIs for multiple AI frameworks including Anthropic's [Model Context Protocol (MCP)](https://modelcontextprotocol.com/), [OpenAI](https://github.com/openai/openai-agents-python), [LangChain](https://github.com/langchain-ai/langchain), and [CrewAI](https://github.com/joaomdmoura/crewAI). It enables users to delegate certain actions in the spend management flow to AI agents or MCP-compatible clients like Claude desktop. These tools are designed for existing Extend users with API keys. If you are not signed up with Extend and would like to learn more about our modern, easy-to-use virtual card and spend management platform for small- and medium-sized businesses, you can check us out at [paywithextend.com](https://www.paywithextend.com/). 
## Features - **Multiple AI Framework Support**: Works with Anthropic Model Context Protocol, OpenAI Agents, LangChain LangGraph & ReAct, and CrewAI frameworks - **Comprehensive Tool Set**: Supports all of Extend's major API functionalities, spanning our Credit Card, Virtual Card, Transaction & Expense Management endpoints ## Installation You don't need this source code unless you want to modify the package. If you just want to use the package, run: ```sh pip install extend_ai_toolkit ``` ### Requirements - **Python**: Version 3.10 or higher - **Extend API Key**: Sign up at [paywithextend.com](https://paywithextend.com) to obtain an API key and secret - **Framework-specific Requirements**: - LangChain: `langchain` and `langchain-openai` packages - OpenAI: `openai` and `openai-agents` packages - CrewAI: `crewai` package - Anthropic: `anthropic` package (for Claude) ## Configuration The library needs to be configured with your Extend API key and API secret, either through command line arguments: ``` --api-key=your_api_key_here --api-secret=your_api_secret_here ``` or via environment variables: ``` EXTEND_API_KEY=your_api_key_here EXTEND_API_SECRET=your_api_secret_here ``` ## Available Tools The toolkit provides a comprehensive set of tools organized by functionality: ### Virtual Cards - `get_virtual_cards`: Fetch virtual cards with optional filters - `get_virtual_card_detail`: Get detailed information about a specific virtual card ### Credit Cards - `get_credit_cards`: List all credit cards - `get_credit_card_detail`: Get detailed information about a specific credit card ### Transactions - `get_transactions`: Fetch transactions with various filters - `get_transaction_detail`: Get detailed information about a specific transaction - `update_transaction_expense_data`: Update expense-related data for a transaction ### Expense Management - `get_expense_categories`: List all expense categories - `get_expense_category`: Get details of a specific expense category -
`get_expense_category_labels`: Get labels for an expense category - `create_expense_category`: Create a new expense category - `create_expense_category_label`: Add a label to an expense category - `update_expense_category`: Modify an existing expense category - `create_receipt_attachment`: Upload a receipt (and optionally attach it to a transaction) - `automatch_receipts`: Initiate an async job to automatch uploaded receipts to transactions - `get_automatch_status`: Get the status of an automatch job - `send_receipt_reminder`: Send a reminder (via email) for a transaction missing a receipt ## Usage Examples ### Model Context Protocol The toolkit provides resources in the `extend_ai_toolkit.modelcontextprotocol` package to help you build an MCP server. #### Development Test the Extend MCP server locally using MCP Inspector: ```bash npx @modelcontextprotocol/inspector python extend_ai_toolkit/modelcontextprotocol/main.py --tools=all ``` #### Claude Desktop Integration Add the Extend MCP server to Claude Desktop by editing the config file: On macOS: `~/Library/Application\ Support/Claude/claude_desktop_config.json` On Windows: `%APPDATA%/Claude/claude_desktop_config.json` If you want to use the `create_receipt_attachment` tool with Claude Desktop, you'll need to install the Filesystem MCP server via `npm install @modelcontextprotocol/server-filesystem` and then add it to the config file as well. Please note: due to current limitations, images uploaded directly to Claude Desktop cannot be uploaded to Extend, because the Claude Desktop app does not have access to the underlying image data. This is why the [Filesystem MCP Server](https://github.com/modelcontextprotocol/servers/tree/main/src/filesystem) is necessary. With the addition of Filesystem, you can set up a dedicated folder for receipts and tell Claude to upload the receipt and automatch it to the most likely transaction.
Alternatively, if you know the transaction you want to attach the receipt to, you can tell Claude to upload the receipt for that transaction (and skip the automatch process). The `filesystem` entry in the example below is optional and only needed for the `create_receipt_attachment` tool; JSON does not allow comments, so simply omit that entry if you don't need it. ```json { "mcpServers": { "extend-mcp": { "command": "python", "args": [ "-m", "extend_ai_toolkit.modelcontextprotocol.main", "--tools=all" ], "env": { "EXTEND_API_KEY": "apik_XXXX", "EXTEND_API_SECRET": "XXXXX" } }, "filesystem": { "command": "npx", "args": [ "-y", "@modelcontextprotocol/server-filesystem", "/path/to/receipts/folder" ] } } } ``` #### Remote Execution You can also run your server remotely and communicate via SSE transport: ```bash python -m extend_ai_toolkit.modelcontextprotocol.main_sse --tools=all --api-key="apikey" --api-secret="apisecret" ``` and optionally connect using the MCP terminal client: ```bash python -m extend_ai_toolkit.modelcontextprotocol.client.mcp_client --mcp-server-host localhost --mcp-server-port 8000 --llm-provider=anthropic --llm-model=claude-3-5-sonnet-20241022 ``` ### OpenAI ```python import os from agents import Agent from extend_ai_toolkit.openai.toolkit import ExtendOpenAIToolkit from extend_ai_toolkit.shared import Configuration, Scope, Product, Actions # Initialize the OpenAI toolkit extend_openai_toolkit = ExtendOpenAIToolkit.default_instance( api_key=os.environ.get("EXTEND_API_KEY"), api_secret=os.environ.get("EXTEND_API_SECRET"), configuration=Configuration( scope=[ Scope(Product.VIRTUAL_CARDS, actions=Actions(read=True)), Scope(Product.CREDIT_CARDS, actions=Actions(read=True)), Scope(Product.TRANSACTIONS, actions=Actions(read=True)), ] ) ) # Create an agent with the tools extend_agent = Agent( name="Extend Agent", instructions="You are an expert at integrating with Extend", tools=extend_openai_toolkit.get_tools() ) ``` ### LangChain ```python import os from langchain_openai import ChatOpenAI from langgraph.prebuilt import create_react_agent from
extend_ai_toolkit.langchain.toolkit import ExtendLangChainToolkit from extend_ai_toolkit.shared import Configuration, Scope, Product, Actions # Initialize the LangChain toolkit extend_langchain_toolkit = ExtendLangChainToolkit.default_instance( api_key=os.environ.get("EXTEND_API_KEY"), api_secret=os.environ.get("EXTEND_API_SECRET"), configuration=Configuration( scope=[ Scope(Product.VIRTUAL_CARDS, actions=Actions(read=True)), Scope(Product.CREDIT_CARDS, actions=Actions(read=True)), Scope(Product.TRANSACTIONS, actions=Actions(read=True)), ] ) ) # Create tools for the agent tools = extend_langchain_toolkit.get_tools() # Create the agent executor langgraph_agent_executor = create_react_agent( ChatOpenAI(model="gpt-4"), tools ) ``` ### CrewAI ```python import os from extend_ai_toolkit.crewai.toolkit import ExtendCrewAIToolkit from extend_ai_toolkit.shared import Configuration, Scope, Product, Actions # Initialize the CrewAI toolkit toolkit = ExtendCrewAIToolkit.default_instance( api_key=os.environ.get("EXTEND_API_KEY"), api_secret=os.environ.get("EXTEND_API_SECRET"), configuration=Configuration( scope=[ Scope(Product.VIRTUAL_CARDS, actions=Actions(read=True)), Scope(Product.CREDIT_CARDS, actions=Actions(read=True)), Scope(Product.TRANSACTIONS, actions=Actions(read=True)), ] ) ) # Configure the LLM (using Claude) toolkit.configure_llm( model="claude-3-opus-20240229", api_key=os.environ.get("ANTHROPIC_API_KEY") ) # Create the Extend agent extend_agent = toolkit.create_agent( role="Extend Integration Expert", goal="Help users manage virtual cards, view credit cards, and check transactions efficiently", backstory="You are an expert at integrating with Extend, with deep knowledge of virtual cards, credit cards, and transaction management.", verbose=True ) # Create a task for handling user queries query_task = toolkit.create_task( description="Process and respond to user queries about Extend services", agent=extend_agent, expected_output="A clear and helpful response 
addressing the user's query", async_execution=True ) # Create a crew with the agent and task crew = toolkit.create_crew( agents=[extend_agent], tasks=[query_task], verbose=True ) # Run the crew result = crew.kickoff() ``` ## Contributing Contributions are welcome! Please feel free to submit a Pull Request. ## License This project is licensed under the MIT License - see the LICENSE file for details. ```

--------------------------------------------------------------------------------
/extend_ai_toolkit/tests/__init__.py:
--------------------------------------------------------------------------------

```python

```

--------------------------------------------------------------------------------
/extend_ai_toolkit/__version__.py:
--------------------------------------------------------------------------------

```python
__version__ = "1.2.0"

```

--------------------------------------------------------------------------------
/extend_ai_toolkit/crewai/__init__.py:
--------------------------------------------------------------------------------

```python
from .toolkit import ExtendCrewAIToolkit

__all__ = ['ExtendCrewAIToolkit']

```

--------------------------------------------------------------------------------
/extend_ai_toolkit/openai/__init__.py:
--------------------------------------------------------------------------------

```python
from .toolkit import ExtendOpenAIToolkit

__all__ = [
    "ExtendOpenAIToolkit",
]

```

--------------------------------------------------------------------------------
/extend_ai_toolkit/langchain/__init__.py:
--------------------------------------------------------------------------------

```python
from .toolkit import ExtendLangChainToolkit

__all__ = [
    "ExtendLangChainToolkit",
]

```

--------------------------------------------------------------------------------
/extend_ai_toolkit/shared/utils.py:
--------------------------------------------------------------------------------

```python
def pop_first(iterable: list, predicate, default=None):
    """Remove and return the first item of ``iterable`` that satisfies ``predicate``.

    Items are tested in order and the search stops at the first truthy
    ``predicate`` result; later items are never passed to ``predicate``.

    Args:
        iterable: List to search. The matched item is removed from it in place.
        predicate: Callable taking one item; a truthy return marks a match.
        default: Value to return when no item matches (the list is untouched).

    Returns:
        The removed item, or ``default`` when nothing matched.
    """
    position = next(
        (idx for idx, candidate in enumerate(iterable) if predicate(candidate)),
        None,
    )
    if position is None:
        return default
    return iterable.pop(position)

```

--------------------------------------------------------------------------------
/extend_ai_toolkit/modelcontextprotocol/__init__.py:
--------------------------------------------------------------------------------

```python
from .options import Options, validate_options
from .server import ExtendMCPServer
from ..__version__ import __version__ as _version

__version__ = _version

__all__ = [
    "ExtendMCPServer",
    "Options",
    "validate_options"
]

```

--------------------------------------------------------------------------------
/.github/pull_request_template.md:
--------------------------------------------------------------------------------

```markdown
## What is this PR doing?

> Brief description of the changes in this pull request. Include the purpose and high level overview.

## Why do we need these changes?

> Explain the reasoning behind these changes (feature, bug-fix, performance improvement, etc)

## Additional Notes

> Anything else you would like to note here.
```

--------------------------------------------------------------------------------
/extend_ai_toolkit/__init__.py:
--------------------------------------------------------------------------------

```python
# NOTE: importing the package eagerly imports the LangChain and OpenAI
# integrations (and therefore their third-party dependencies). The CrewAI
# toolkit is not re-exported here; import it from extend_ai_toolkit.crewai.
from .__version__ import __version__ as _version
from .langchain import ExtendLangChainToolkit
from .modelcontextprotocol import ExtendMCPServer, Options, validate_options
from .openai import ExtendOpenAIToolkit

__version__ = _version

__all__ = [
    "ExtendLangChainToolkit",
    "ExtendMCPServer",
    "ExtendOpenAIToolkit",
    "Options",
    "validate_options",
]

```

--------------------------------------------------------------------------------
/extend_ai_toolkit/shared/interfaces.py:
--------------------------------------------------------------------------------

```python
"""Structural typing helpers shared by the framework-specific toolkits."""
from typing import Protocol

from typing_extensions import TypeVar

from .api import ExtendAPI
from .tools import Tool

# The framework-native tool type a toolkit produces (AgentToolkit is generic
# over it, e.g. the LangChain ExtendTool or an OpenAI Agents FunctionTool).
ToolType = TypeVar("ToolType", covariant=True)
# TypeVar bound to AgentToolInterface implementations; it is not referenced
# elsewhere in this module.
AgentToolType = TypeVar("AgentToolType", bound="AgentToolInterface[ToolType]")


class AgentToolInterface(Protocol[ToolType]):
    """Structural interface for wrappers that adapt one Extend ``Tool`` to a framework.

    NOTE(review): the LangChain/CrewAI ``ExtendTool`` classes in this package
    do not define ``build_tool``, so they do not structurally satisfy this
    Protocol — confirm whether it is still used.
    """

    def __init__(self, extend_api: ExtendAPI, tool: Tool) -> None:
        """Bind the wrapper to an API client and the tool definition it exposes."""
        ...

    def build_tool(self) -> ToolType:
        """Return the framework-native tool object."""
        ...
```

--------------------------------------------------------------------------------
/extend_ai_toolkit/modelcontextprotocol/client/__init__.py:
--------------------------------------------------------------------------------

```python
from .anthropic_chat_client import AnthropicChatClient
from .chat_client import ChatClient
from .openai_chat_client import OpenAIChatClient

__all__ = [
    "AnthropicChatClient",
    "ChatClient",
    "MCPClient",
    "OpenAIChatClient",
]


def __getattr__(name):
    """Import ``MCPClient`` on first access instead of at package import (PEP 562)."""
    if name == "MCPClient":
        from .mcp_client import MCPClient
        return MCPClient
    # Same wording as CPython's own AttributeError for missing module attributes.
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

```

--------------------------------------------------------------------------------
/extend_ai_toolkit/shared/models.py:
--------------------------------------------------------------------------------

```python
from dataclasses import dataclass
from typing import Optional, TypedDict

from .enums import Product


class Actions(TypedDict, total=False):
    """Action flags for a product scope; every key is optional (``total=False``)."""
    create: Optional[bool]
    update: Optional[bool]
    read: Optional[bool]


@dataclass
class Scope:
    """A product together with the actions enabled on it."""
    type: Product
    actions: Actions

    @staticmethod
    def from_str(product_str: str, actions_str: str) -> "Scope":
        """Build a Scope enabling exactly one action, e.g. ``("credit_cards", "read")``.

        Raises:
            ValueError: If ``product_str`` is not a ``Product`` value.
        """
        return Scope(Product(product_str), Actions(**{actions_str: True}))

```

--------------------------------------------------------------------------------
/extend_ai_toolkit/shared/__init__.py:
--------------------------------------------------------------------------------

```python
from . import functions
from .agent_toolkit import AgentToolkit
from .api import ExtendAPI
from .configuration import Configuration, Product, Scope, Actions, validate_tool_spec
from .enums import ExtendAPITools, Agent, Action
from .interfaces import AgentToolInterface
from .tools import Tool, tools

__all__ = [
    "Agent",
    "AgentToolInterface",
    "Configuration",
    "AgentToolkit",
    "ExtendAPI",
    "ExtendAPITools",
    "Tool",
    "Product",
    "Scope",
    "Action",
    "Actions",
    "tools",
    "functions",
    "validate_tool_spec",
    # NOTE(review): "helpers" is not imported above. A star-import should load
    # the submodule, but plain attribute access (shared.helpers) only works
    # after something else has imported it.
    "helpers"
]

```

--------------------------------------------------------------------------------
/extend_ai_toolkit/langchain/toolkit.py:
--------------------------------------------------------------------------------

```python
from typing import Optional

from extend_ai_toolkit.shared import (
    AgentToolkit,
    Configuration,
    ExtendAPI,
    Tool
)
from .extend_tool import ExtendTool


class ExtendLangChainToolkit(AgentToolkit[ExtendTool]):
    """Toolkit that exposes Extend API operations as LangChain tools."""

    def __init__(
            self,
            extend_api: ExtendAPI,
            configuration: Optional[Configuration] = None
    ):
        """Create the toolkit.

        Args:
            extend_api: API client shared by every generated tool.
            configuration: Which tools to expose; a falsy value (including the
                default ``None``) falls back to ``Configuration.all_tools()``.
        """
        super().__init__(
            extend_api=extend_api,
            configuration=configuration or Configuration.all_tools()
        )

    @classmethod
    def default_instance(cls, api_key: str, api_secret: str, configuration: Configuration) -> "ExtendLangChainToolkit":
        """Construct the toolkit with a default ``ExtendAPI`` client built from credentials."""
        return cls(
            extend_api=ExtendAPI.default_instance(api_key, api_secret),
            configuration=configuration
        )

    def tool_for_agent(self, api: ExtendAPI, tool: Tool) -> ExtendTool:
        """Wrap one shared ``Tool`` as a LangChain ``ExtendTool``."""
        return ExtendTool(api, tool)

```

--------------------------------------------------------------------------------
/extend_ai_toolkit/modelcontextprotocol/main.py:
--------------------------------------------------------------------------------

```python
import sys

from colorama import Fore, Style
from dotenv import load_dotenv

from extend_ai_toolkit.modelcontextprotocol import ExtendMCPServer, Options
from extend_ai_toolkit.shared import Configuration
from extend_ai_toolkit.shared.configuration import VALID_SCOPES

load_dotenv()


def build_server():
    """Build an ``ExtendMCPServer`` from the command-line arguments.

    Parses ``sys.argv[1:]`` with ``Options.from_args`` and converts the
    selected tool string into a ``Configuration``.
    """
    options = Options.from_args(sys.argv[1:], VALID_SCOPES)
    configuration = Configuration.from_tool_str(options.tools)
    return ExtendMCPServer.default_instance(
        api_key=options.api_key,
        api_secret=options.api_secret,
        configuration=configuration
    )


def handle_error(error):
    """Write ``error`` to stderr in yellow and reset the terminal colour afterwards."""
    sys.stderr.write(f"{Fore.YELLOW} {error}{Style.RESET_ALL}\n")


# NOTE(review): built at import time, presumably so tooling that imports this
# module can reach ``server``; errors raised here bypass handle_error.
server = build_server()

if __name__ == "__main__":
    try:
        # With the stdio transport, stdout carries the MCP protocol stream, so
        # status text must go to stderr. run() is expected to block until the
        # server stops, so the message is written before starting it.
        sys.stderr.write("Extend MCP server is running.\n")
        server.run(transport='stdio')
    except Exception as e:
        handle_error(e)
        sys.exit(1)

```

--------------------------------------------------------------------------------
/extend_ai_toolkit/openai/toolkit.py:
--------------------------------------------------------------------------------

```python
from typing import Optional

from agents import FunctionTool

from extend_ai_toolkit.shared import (
    AgentToolkit,
    Configuration,
    ExtendAPI,
    Tool
)
from .extend_tool import ExtendTool


class ExtendOpenAIToolkit(AgentToolkit[FunctionTool]):
    """Toolkit that exposes Extend API operations as OpenAI Agents ``FunctionTool`` objects."""

    def __init__(
            self,
            extend_api: ExtendAPI,
            configuration: Optional[Configuration] = None
    ):
        """Create the toolkit.

        Args:
            extend_api: API client shared by every generated tool.
            configuration: Which tools to expose; a falsy value (including the
                default ``None``) falls back to ``Configuration.all_tools()``.
        """
        super().__init__(
            extend_api=extend_api,
            configuration=configuration or Configuration.all_tools()
        )

    @classmethod
    def default_instance(cls, api_key: str, api_secret: str, configuration: Configuration) -> "ExtendOpenAIToolkit":
        """Construct the toolkit with a default ``ExtendAPI`` client built from credentials."""
        return cls(
            extend_api=ExtendAPI.default_instance(api_key, api_secret),
            configuration=configuration
        )

    def tool_for_agent(self, api: ExtendAPI, tool: Tool) -> FunctionTool:
        """Wrap one shared ``Tool`` as an OpenAI Agents ``FunctionTool``."""
        return ExtendTool(api, tool)

```

--------------------------------------------------------------------------------
/extend_ai_toolkit/langchain/extend_tool.py:
--------------------------------------------------------------------------------

```python
from typing import Any

from langchain_core.tools import BaseTool
from pydantic import Field

from extend_ai_toolkit.shared import ExtendAPI, Tool


class ExtendTool(BaseTool):
    """Tool for interacting with Extend API."""

    extend_api: ExtendAPI = Field(description="The Extend API client")
    method: str = Field(description="The method to call on the Extend API")

    def __init__(
            self,
            extend_api: ExtendAPI,
            tool: Tool,
    ):
        """Build a LangChain tool from an Extend tool definition.

        Name, description and argument schema are copied from ``tool``;
        ``tool.method.value`` selects the operation passed to ``ExtendAPI.run``.
        """
        super().__init__(
            name=tool.name,
            description=tool.description,
            args_schema=tool.args_schema,
            extend_api=extend_api,
            method=tool.method.value
        )

    async def _arun(
            self,
            *args: Any,
            **kwargs: Any,
    ) -> str:
        """Forward the call to ``ExtendAPI.run`` for this tool's method."""
        return await self.extend_api.run(self.method, *args, **kwargs)

    def _run(
            self,
            *args: Any,
            **kwargs: Any,
    ) -> str:
        """Synchronous execution is not supported; invoke the tool asynchronously."""
        raise NotImplementedError("ExtendTool only supports async operations")

```

--------------------------------------------------------------------------------
/extend_ai_toolkit/openai/extend_tool.py:
--------------------------------------------------------------------------------

```python
import json
from typing import Any

from agents import FunctionTool
from agents.run_context import RunContextWrapper

from extend_ai_toolkit.shared import ExtendAPI, Tool


def ExtendTool(api: ExtendAPI, tool: Tool) -> FunctionTool:
    """Wrap an Extend ``Tool`` as an OpenAI Agents ``FunctionTool``.

    The tool's pydantic argument schema is converted to JSON Schema with
    metadata keys (titles, descriptions, defaults) stripped, and invocations
    are forwarded to ``api.run`` with the decoded JSON arguments.
    """

    async def on_invoke_tool(ctx: RunContextWrapper[Any], input_str: str) -> str:
        # An empty/blank argument string (possible for parameterless tools —
        # TODO confirm against the Agents SDK) means "no arguments" rather
        # than a JSON decode error.
        arguments = json.loads(input_str) if input_str and input_str.strip() else {}
        return await api.run(tool.method.value, **arguments)

    parameters = tool.args_schema.model_json_schema()
    parameters["additionalProperties"] = False
    parameters["type"] = "object"
    # Strip schema metadata that is not needed in the function definition.
    parameters.pop("description", None)
    parameters.pop("title", None)
    for prop in parameters.get("properties", {}).values():
        prop.pop("title", None)
        prop.pop("default", None)

    return FunctionTool(
        name=tool.method.value,
        description=tool.description,
        params_json_schema=parameters,
        on_invoke_tool=on_invoke_tool,
        strict_json_schema=False
    )

```

--------------------------------------------------------------------------------
/extend_ai_toolkit/shared/agent_toolkit.py:
--------------------------------------------------------------------------------

```python
from abc import abstractmethod
from typing import List, Generic

from pydantic import PrivateAttr

from .api import ExtendAPI
from .configuration import Configuration
from .enums import Agent
from .interfaces import ToolType
from .tools import Tool, tools


class AgentToolkit(Generic[ToolType]):
    """Base class for toolkits that adapt Extend tools to an agent framework.

    On construction, every tool allowed by ``configuration`` is converted with
    :meth:`tool_for_agent`; :meth:`get_tools` returns the converted list.

    NOTE(review): this class does not use ``ABCMeta``, so ``@abstractmethod``
    is not enforced at instantiation; the ``NotImplementedError`` in
    ``tool_for_agent`` is the effective guard.
    """

    # default_factory rather than default=[] so no single list object is shared
    # as the default value; __init__ always assigns a fresh list anyway.
    _tools: List[ToolType] = PrivateAttr(default_factory=list)
    agent: Agent

    def __init__(
            self,
            extend_api: ExtendAPI,
            configuration: Configuration,
    ):
        super().__init__()
        self._tools = [
            self.tool_for_agent(extend_api, tool)
            for tool in configuration.allowed_tools(tools)
        ]

    @classmethod
    def default_instance(cls, api_key: str, api_secret: str, configuration: Configuration) -> "AgentToolkit":
        """Construct the toolkit with a default ``ExtendAPI`` client built from credentials."""
        return cls(
            extend_api=ExtendAPI.default_instance(api_key, api_secret),
            configuration=configuration
        )

    @abstractmethod
    def tool_for_agent(self, api: ExtendAPI, tool: Tool) -> ToolType:
        """Convert one shared ``Tool`` into the framework's tool type."""
        raise NotImplementedError("Subclasses must implement tool_for_agent()")

    def get_tools(self) -> List[ToolType]:
        """Return the tools built at construction time."""
        return self._tools

```

--------------------------------------------------------------------------------
/extend_ai_toolkit/crewai/extend_tool.py:
--------------------------------------------------------------------------------

```python
from typing import Any
import asyncio

from crewai.tools import BaseTool
from pydantic import Field, ConfigDict

from extend_ai_toolkit.shared import ExtendAPI, Tool


class ExtendCrewAITool(BaseTool):
    """Tool for interacting with the Extend API in CrewAI."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    extend_api: ExtendAPI = Field(description="The Extend API client")
    method: str = Field(description="The method to call on the Extend API")

    def __init__(self, api: ExtendAPI, tool: Tool):
        # Name and method both come from tool.method.value; the description and
        # argument schema are copied from the shared tool definition.
        super().__init__(
            name=tool.method.value,
            description=tool.description,
            args_schema=tool.args_schema,
            extend_api=api,
            method=tool.method.value
        )

    async def _arun(self, **kwargs: Any) -> str:
        """Run the tool asynchronously."""
        return await self.extend_api.run(self.method, **kwargs)

    def _run(self, **kwargs: Any) -> str:
        """Run the tool synchronously on a fresh event loop.

        ``asyncio.run`` creates a new loop, runs the coroutine to completion,
        then closes the loop and clears the thread's current event loop, so no
        closed loop is left installed for later asyncio calls in this thread.

        Raises:
            RuntimeError: If called while an event loop is already running in
                this thread (await ``_arun`` from async code instead).
        """
        return asyncio.run(self._arun(**kwargs))

```

--------------------------------------------------------------------------------
/extend_ai_toolkit/examples/openai-agent.py:
--------------------------------------------------------------------------------

```python
import os

from agents import Agent, Runner
from dotenv import load_dotenv

from extend_ai_toolkit.openai.toolkit import ExtendOpenAIToolkit
from extend_ai_toolkit.shared import Configuration, Scope, Product, Actions

# Load environment variables
load_dotenv()

api_key = os.environ.get("EXTEND_API_KEY")
api_secret = os.environ.get("EXTEND_API_SECRET")


async def main():
    extend_openai_toolkit = ExtendOpenAIToolkit.default_instance(
        api_key,
        api_secret,
        Configuration(
            scope=[
                Scope(Product.VIRTUAL_CARDS, actions=Actions(read=True)),
                Scope(Product.CREDIT_CARDS, actions=Actions(read=True)),
                Scope(Product.TRANSACTIONS, actions=Actions(read=True)),
            ]
        )
    )

    extend_agent = Agent(
        name="Extend Agent",
        instructions="You are an expert at integrating with Extend. You can help users manage virtual cards, view credit cards, and check transactions.",
        tools=extend_openai_toolkit.get_tools(),
        model="gpt-4o",
    )

    # Example interaction with the agent
    print("Welcome to the Extend OpenAI Agent! Type 'quit' to exit.")
    while True:
        user_input = input("\nYour question: ").strip()
        if user_input.lower() == 'quit':
            break

        response = await Runner.run(extend_agent, user_input)
        print("Agent response:", response.final_output)


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())

```

--------------------------------------------------------------------------------
/extend_ai_toolkit/modelcontextprotocol/client/chat_client.py:
--------------------------------------------------------------------------------

```python
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Tuple, Optional


class ChatClient(ABC):
    """
    Abstract base class for LLM API clients.
    Implementations handle specific provider APIs (OpenAI, Anthropic, etc.)
    """

    @abstractmethod
    async def generate_completion(
            self,
            messages: List[Dict[str, Any]],
            functions: List[Dict[str, Any]],
            max_tokens: int) -> Tuple[Optional[str], Optional[Dict]]:
        """
        Generate a completion from the LLM.

        Args:
            messages: List of message dictionaries with role and content
            functions: List of function definitions
            max_tokens: Maximum tokens to generate

        Returns:
            Tuple containing:
            - content: The text response if no function call (None if function call)
            - function_call: Dictionary with name and arguments if a function call is needed,
              or None if no function call
        """
        pass

    @abstractmethod
    async def generate_with_tool_result(
            self,
            messages: List[Dict[str, Any]],
            max_tokens: int) -> str:
        """
        Generate a follow-up completion after a tool call.

        Args:
            messages: List of message dictionaries including tool results
            max_tokens: Maximum tokens to generate

        Returns:
            Text response from the model
        """
        pass

```

--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------

```yaml
name: Release

on:
  push:
    branches:
      - main

permissions:
  contents: write

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: 3.11

      - name: Install dependencies
        run: sudo apt-get install make

      - name: Create virtual environment
        run: make venv

      - name: Build package
        run: |
          set -x
          source venv/bin/activate
          rm -rf build dist *.egg-info
          make build ENV=stage

      - name: Extract Version from version file
        id: get_version
        run: |
          VERSION=$(grep -Po '^__version__\s*=\s*"\K[^"]+' extend_ai_toolkit/__version__.py)
          echo "Version extracted: $VERSION"
          echo "version=$VERSION" >> $GITHUB_OUTPUT

      # NOTE(review): actions/create-release is archived and unmaintained;
      # consider `gh release create` or a maintained release action.
      - name: Create GitHub Release
        id: create_release
        uses: actions/create-release@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          tag_name: v${{ steps.get_version.outputs.version }}
          release_name: v${{ steps.get_version.outputs.version }}
          draft: false
          prerelease: false

      - name: Install Twine
        run: |
          source venv/bin/activate
          pip install twine

      - name: Upload to PyPI
        env:
          PYPI_API_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
        run: |
          source venv/bin/activate
          twine upload dist/* -u __token__ -p $PYPI_API_TOKEN
```

--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------

```toml
[project]
name = "extend_ai_toolkit"
description = "Extend AI Toolkit"
authors = [
    { name = "Extend Engineering", email = "[email protected]" },
]
dynamic = ["version"]
keywords = ["extend", "api", "virtual cards", "payments", "ai", "agent", "mcp"]
readme = "README.md"
license = { text = "MIT" }
requires-python = ">=3.10"
# NOTE(review): mypy and build look like dev/CI-only tools; consider moving
# them to the dev extra.
dependencies = [
    "mcp>=1.4.1",
    "mypy==1.15.0",
    "python-dotenv>=1.0.1",
    "langchain==0.3.20",
    "colorama>=0.4.4",
    # The package uses pydantic v2-only APIs (ConfigDict, model_json_schema).
    "pydantic>=2.0",
    "requests==2.32.3",
    "build",
    "starlette>=0.40.0,<0.46.0",
    "openai>=1.66.3,<2.0.0",
    "openai-agents==0.0.4",
    "paywithextend==1.2.2",
]

[project.urls]
"Issue Tracker" = "https://github.com/paywithextend/extend-ai-toolkit/issues"
"Source Code" = "https://github.com/paywithextend/extend-ai-toolkit"

[project.optional-dependencies]
dev = ["pytest>=7.0.1", "mypy>=1.11.1", "ruff>=0.6.1", "crewai>=0.108.0", "pytest-asyncio>=0.26.0"]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.version]
path = "extend_ai_toolkit/__version__.py"

[tool.hatch.build.targets.wheel]
packages = ["extend_ai_toolkit"]

[tool.hatch.build]
packages = ["extend_ai_toolkit"]
exclude = ["extend_ai_toolkit/tests/**"]

[tool.hatch.metadata]
allow-direct-references = true

[tool.ruff]
lint.select = [
    "E",  # pycodestyle
    "F",  # pyflakes
    "I",  # isort
    "D",  # pydocstyle
    "T201",
    "UP",
]
lint.ignore = [
    "UP006",
    "UP007",
    "UP035",
    "D417",
    "E501",
]
[tool.ruff.lint.per-file-ignores]
"extend_ai_toolkit/tests/*" = ["D", "UP", "I002"]

[tool.ruff.lint.pydocstyle]
convention = "google"
```
--------------------------------------------------------------------------------
/extend_ai_toolkit/shared/enums.py:
--------------------------------------------------------------------------------

```python
from enum import Enum


class ExtendAPITools(Enum):
    """Extend API operations that can be exposed as agent tools.

    The string value of each member is the operation name that the toolkits
    forward to ``ExtendAPI.run`` (e.g. ``ExtendAPITools.GET_VIRTUAL_CARDS.value``).
    """
    GET_VIRTUAL_CARDS = "get_virtual_cards"
    GET_VIRTUAL_CARD_DETAIL = "get_virtual_card_detail"
    CANCEL_VIRTUAL_CARD = "cancel_virtual_card"
    CLOSE_VIRTUAL_CARD = "close_virtual_card"
    GET_CREDIT_CARDS = "get_credit_cards"
    GET_CREDIT_CARD_DETAIL = "get_credit_card_detail"
    GET_TRANSACTIONS = "get_transactions"
    GET_TRANSACTION_DETAIL = "get_transaction_detail"
    UPDATE_TRANSACTION_EXPENSE_DATA = "update_transaction_expense_data"
    GET_EXPENSE_CATEGORIES = "get_expense_categories"
    GET_EXPENSE_CATEGORY = "get_expense_category"
    GET_EXPENSE_CATEGORY_LABELS = "get_expense_category_labels"
    CREATE_EXPENSE_CATEGORY = "create_expense_category"
    CREATE_EXPENSE_CATEGORY_LABEL = "create_expense_category_label"
    UPDATE_EXPENSE_CATEGORY = "update_expense_category"
    UPDATE_EXPENSE_CATEGORY_LABEL = "update_expense_category_label"
    PROPOSE_EXPENSE_CATEGORY_LABEL = "propose_expense_category_label"
    CONFIRM_EXPENSE_CATEGORY_LABEL = "confirm_expense_category_label"
    CREATE_RECEIPT_ATTACHMENT = "create_receipt_attachment"
    AUTOMATCH_RECEIPTS = "automatch_receipts"
    GET_AUTOMATCH_STATUS = "get_automatch_status"
    SEND_RECEIPT_REMINDER = "send_receipt_reminder"


class Action(Enum):
    """Actions a scope can grant on a product.

    Values match the action half of a ``product.action`` tool spec
    (e.g. the ``read`` in ``virtual_cards.read``).
    """
    CREATE = "create"
    READ = "read"
    UPDATE = "update"


class Agent(Enum):
    """Identifiers for supported agent frameworks."""
    # NOTE(review): there is a crewai toolkit package but no CREWAI member —
    # confirm whether this enum is still meant to be exhaustive.
    OPENAI = "openai"
    LANGCHAIN = "langchain"


class Product(Enum):
    """Extend products that scopes are granted on.

    Values match the product half of a ``product.action`` tool spec
    (e.g. the ``virtual_cards`` in ``virtual_cards.read``).
    """
    CREDIT_CARDS = "credit_cards"
    VIRTUAL_CARDS = "virtual_cards"
    TRANSACTIONS = "transactions"
    EXPENSE_CATEGORIES = "expense_categories"
    RECEIPT_ATTACHMENTS = "receipt_attachments"
```
--------------------------------------------------------------------------------
/extend_ai_toolkit/tests/test_validate_tool_spec.py:
-------------------------------------------------------------------------------- ```python import pytest from extend_ai_toolkit.shared import validate_tool_spec, Product, Actions def test_validate_tool_spec_valid(): # Valid input: "virtual_cards.read" product, action = validate_tool_spec("virtual_cards.read") assert product == Product.VIRTUAL_CARDS assert action == "read" # Another valid input: "credit_cards.create" product, action = validate_tool_spec("credit_cards.create") assert product == Product.CREDIT_CARDS assert action == "create" # Additional valid input: "expense_categories.read" product, action = validate_tool_spec("expense_categories.read") assert product == Product.EXPENSE_CATEGORIES assert action == "read" def test_validate_tool_spec_invalid_format(): # Missing dot should raise a ValueError. with pytest.raises(ValueError) as exc_info: validate_tool_spec("invalidformat") assert "must be in the format 'product.action'" in str(exc_info.value) def test_validate_tool_spec_invalid_product(): # Invalid product should raise a ValueError. with pytest.raises(ValueError) as exc_info: validate_tool_spec("nonexistent.read") # Check if error message mentions valid products. assert "Invalid product" in str(exc_info.value) def test_validate_tool_spec_invalid_actions(): # Invalid action should raise a ValueError. with pytest.raises(ValueError) as exc_info: validate_tool_spec("credit_cards.invalid") # Check if error message mentions valid action. 
valid_actions = list(Actions.__annotations__.keys()) assert "Invalid action" in str(exc_info.value) for perm in valid_actions: assert perm in str(exc_info.value) or str(valid_actions) in str(exc_info.value) ``` -------------------------------------------------------------------------------- /extend_ai_toolkit/modelcontextprotocol/client/openai_chat_client.py: -------------------------------------------------------------------------------- ```python # openai_client.py import os from typing import List, Dict, Any, Tuple, Optional from openai import AsyncOpenAI from extend_ai_toolkit.modelcontextprotocol.client import ChatClient class OpenAIChatClient(ChatClient): """Implementation of ChatClient for OpenAI's API""" def __init__(self, model_name="gpt-4o"): self.model_name = model_name self.client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY")) async def generate_completion( self, messages: List[Dict[str, Any]], functions: List[Dict[str, Any]], max_tokens: int ) -> Tuple[Optional[str], Optional[Dict]]: response = await self.client.chat.completions.create( model=self.model_name, max_tokens=max_tokens, messages=messages, functions=functions ) choice = response.choices[0] # Check if the assistant wants to call a function if choice.finish_reason == "function_call": func_call = choice.message.function_call function_call_info = { "name": func_call.name, "arguments": func_call.arguments } return None, function_call_info else: # No function call; return the assistant's message directly return choice.message.content, None async def generate_with_tool_result( self, messages: List[Dict[str, Any]], max_tokens: int ) -> str: follow_up = await self.client.chat.completions.create( model=self.model_name, max_tokens=max_tokens, messages=messages, ) return follow_up.choices[0].message.content ``` -------------------------------------------------------------------------------- /extend_ai_toolkit/modelcontextprotocol/main_sse.py: 
-------------------------------------------------------------------------------- ```python import os import sys import uvicorn from colorama import Fore from dotenv import load_dotenv from mcp.server import Server from mcp.server.sse import SseServerTransport from starlette.applications import Starlette from starlette.requests import Request from starlette.routing import Mount, Route from extend_ai_toolkit.modelcontextprotocol import ExtendMCPServer, Options from extend_ai_toolkit.shared import Configuration from extend_ai_toolkit.shared.configuration import VALID_SCOPES load_dotenv() def build_starlette_app(sse_server: Server, *, debug: bool = False) -> Starlette: """Create a Starlette application that can serve the provided mcp server with SSE.""" sse = SseServerTransport("/messages/") async def handle_sse(request: Request) -> None: async with sse.connect_sse( request.scope, request.receive, request._send, # noqa: SLF001 ) as (read_stream, write_stream): await sse_server.run( read_stream, write_stream, mcp_server.create_initialization_options(), ) return Starlette( debug=debug, routes=[ Route("/sse", endpoint=handle_sse), Mount("/messages/", app=sse.handle_post_message), ], ) def build_server(): options = Options.from_args((sys.argv[1:]), VALID_SCOPES) selected_tools = options.tools configuration = Configuration.from_tool_str(selected_tools) return ExtendMCPServer.default_instance( api_key=options.api_key, api_secret=options.api_secret, configuration=configuration ) server = build_server() def handle_error(error): sys.stderr.write(f"\n{Fore.RED} {str(error)}\n") if __name__ == "__main__": try: mcp_server = server._mcp_server import argparse host = os.environ.get("MCP_HOST", "127.0.0.1") port = os.environ.get("MCP_PORT", "8000") # Default to port 8000 if not set # Bind SSE request handling to MCP server starlette_app = build_starlette_app(mcp_server, debug=True) uvicorn.run(starlette_app, host=host, port=int(port)) except Exception as e: handle_error(e) ``` 
-------------------------------------------------------------------------------- /extend_ai_toolkit/modelcontextprotocol/options.py: -------------------------------------------------------------------------------- ```python import os def validate_options(cls): original_init = cls.__init__ def new_init(self, *args, **kwargs): original_init(self, *args, **kwargs) # Perform validation after initialization if not self.api_key: raise ValueError( 'Extend API key not provided. Please either pass it as an argument --api-key=$KEY or set the EXTEND_API_KEY environment variable.' ) elif not self.api_key.startswith("apik_"): raise ValueError('Extend API key must start with "apik_".') if not self.api_secret: raise ValueError( 'Extend API key not provided. Please either pass it as an argument --api-key=$KEY or set the EXTEND_API_SECRET environment variable.' ) if not self.tools: raise ValueError('The --tools argument must be provided.') cls.__init__ = new_init return cls @validate_options class Options: ACCEPTED_ARGS = ['api-key', 'api-secret', 'tools'] def __init__(self, tools, api_key, api_secret): self.tools = tools self.api_key = api_key self.api_secret = api_secret @staticmethod def from_args(args: list[str], valid_tools: list[str]) -> "Options": tools = "" api_key = None api_secret = None for arg in args: if arg.startswith("--"): arg_body = arg[2:] if "=" not in arg_body: raise ValueError(f"Argument {arg} is not in --key=value format.") key, value = arg_body.split("=", 1) match key: case "tools": tools = value case "api-key": api_key = value case "api-secret": api_secret = value case _: raise ValueError( f"Invalid argument: {key}. Accepted arguments are: {', '.join(Options.ACCEPTED_ARGS)}" ) for tool in tools.split(","): if tool.strip() == "all": continue if tool.strip() not in valid_tools: raise ValueError( f"Invalid tool: {tool}. 
Accepted tools are: {', '.join(valid_tools)}" ) api_key = api_key or os.environ.get("EXTEND_API_KEY") api_secret = api_secret or os.environ.get("EXTEND_API_SECRET") return Options(tools, api_key, api_secret) ``` -------------------------------------------------------------------------------- /extend_ai_toolkit/examples/langchain-react-agent.py: -------------------------------------------------------------------------------- ```python import os import asyncio from dotenv import load_dotenv from langchain_openai import ChatOpenAI from langgraph.prebuilt import create_react_agent from langchain_core.messages import SystemMessage, AIMessage, HumanMessage from extend_ai_toolkit.langchain.toolkit import ExtendLangChainToolkit from extend_ai_toolkit.shared import Configuration, Scope, Product, Actions # Load environment variables load_dotenv() # Get required environment variables api_key = os.environ.get("EXTEND_API_KEY") api_secret = os.environ.get("EXTEND_API_SECRET") # Validate environment variables if not all([api_key, api_secret]): raise ValueError("Missing required environment variables. 
Please set EXTEND_API_KEY and EXTEND_API_SECRET") llm = ChatOpenAI( model="gpt-4o", ) extend_langchain_toolkit = ExtendLangChainToolkit.default_instance( api_key, api_secret, Configuration( scope=[ Scope(Product.VIRTUAL_CARDS, actions=Actions(read=True,update=True)), Scope(Product.CREDIT_CARDS, actions=Actions(read=True)), Scope(Product.TRANSACTIONS, actions=Actions(read=True,update=True)), Scope(Product.EXPENSE_CATEGORIES, actions=Actions(read=True)), Scope(Product.RECEIPT_ATTACHMENTS, actions=Actions(read=True)), ] ) ) tools = [] tools.extend(extend_langchain_toolkit.get_tools()) # Create the react agent langgraph_agent_executor = create_react_agent( llm, tools ) async def chat_with_agent(): print("\nWelcome to the Extend AI Assistant!") print("You can ask me to:") print("- List all credit cards") print("- List all virtual cards") print("- Show details for a specific virtual card") print("- Show transactions for a specific period") print("- Show details for a specific transaction") print("\nType 'exit' to end the conversation.\n") while True: # Get user input user_input = input("\nYou: ").strip() if user_input.lower() in ['exit', 'quit', 'bye']: print("\nGoodbye!") break # Process the query result = await langgraph_agent_executor.ainvoke({ "input": user_input, "messages": [ SystemMessage(content="You are a helpful assistant that can interact with the Extend API to manage virtual cards, credit cards, and transactions."), HumanMessage(content=user_input) ] }) # Extract and print the assistant's message for message in result.get('messages', []): if isinstance(message, AIMessage): print("\nAssistant:", message.content) if __name__ == "__main__": asyncio.run(chat_with_agent()) ``` -------------------------------------------------------------------------------- /extend_ai_toolkit/crewai/toolkit.py: -------------------------------------------------------------------------------- ```python from typing import List, Optional from crewai import Agent, Task, Crew, LLM from 
crewai.tools import BaseTool from extend_ai_toolkit.shared import ( AgentToolkit, Configuration, ExtendAPI, Tool ) from .extend_tool import ExtendCrewAITool class ExtendCrewAIToolkit(AgentToolkit[BaseTool]): """Toolkit for integrating Extend API with CrewAI.""" def __init__( self, extend_api: ExtendAPI, configuration: Optional[Configuration] = None ): super().__init__( extend_api=extend_api, configuration=configuration ) self._llm = None @classmethod def default_instance(cls, api_key: str, api_secret: str, configuration: Configuration) -> "ExtendCrewAIToolkit": return cls( extend_api=ExtendAPI.default_instance(api_key, api_secret), configuration=configuration ) def configure_llm( self, model: str, api_key: Optional[str] = None, **kwargs ) -> None: """Configure the LLM for use with agents. Args: model: The model identifier to use (e.g., 'gpt-4', 'claude-3-opus-20240229') api_key: Optional API key for the model provider **kwargs: Additional arguments to pass to the LLM constructor """ self._llm = LLM( model=model, api_key=api_key, **kwargs ) def tool_for_agent(self, api: ExtendAPI, tool: Tool) -> BaseTool: """Convert an Extend tool to a CrewAI tool.""" return ExtendCrewAITool(api, tool) def create_agent( self, role: str, goal: str, backstory: str, tools: Optional[List[BaseTool]] = None, verbose: bool = True ) -> Agent: """Create a CrewAI agent with Extend tools.""" if tools is None: tools = self.get_tools() if self._llm is None: raise ValueError("No LLM configured. 
Call configure_llm() first.") return Agent( role=role, goal=goal, backstory=backstory, tools=tools, verbose=verbose, llm=self._llm ) def create_task( self, description: str, agent: Agent, expected_output: Optional[str] = None, async_execution: bool = True ) -> Task: """Create a CrewAI task.""" return Task( description=description, agent=agent, expected_output=expected_output, async_execution=async_execution ) def create_crew( self, agents: List[Agent], tasks: List[Task], verbose: bool = True ) -> Crew: """Create a CrewAI crew with agents and tasks.""" return Crew( agents=agents, tasks=tasks, verbose=verbose ) ``` -------------------------------------------------------------------------------- /extend_ai_toolkit/shared/configuration.py: -------------------------------------------------------------------------------- ```python from typing import Optional, List from pydantic.v1 import BaseModel from .enums import Product, Action from .models import Scope, Actions from .tools import Tool from .utils import pop_first VALID_SCOPES = [ 'virtual_cards.read', 'virtual_cards.update', 'credit_cards.read', 'transactions.read', 'transactions.update', 'expense_categories.read', 'expense_categories.create', 'expense_categories.update', 'receipt_attachments.read', 'receipt_attachments.create' ] class Configuration(BaseModel): scope: Optional[List[Scope]] = None def add_scope(self, scope): if not self.scope: self.scope = [] self.scope.append(scope) def allowed_tools(self, tools) -> list[Tool]: return [tool for tool in tools if self.is_tool_in_scope(tool)] def is_tool_in_scope(self, tool: Tool) -> bool: if not self.scope: return False for tool_scope in tool.required_scope: configured_scope = next( filter(lambda x: x.type == tool_scope.type, self.scope), None ) if configured_scope is None: return False for action, required in tool_scope.actions.items(): if required and not configured_scope.actions.get(action, False): return False return True @classmethod def all_tools(cls) -> 
"Configuration": scopes: List[Scope] = [] for tool in VALID_SCOPES: product_str, action_str = tool.split(".") scope: Scope = pop_first( scopes, lambda x: x.type.value == product_str, default=None ) if scope: action = Action(action_str) scope.actions[action.value] = True scopes.append(scope) else: scope = Scope.from_str(product_str, action_str) scopes.append(scope) return cls(scope=scopes) @classmethod def from_tool_str(cls, tools: str) -> "Configuration": configuration = cls(scope=[]) tool_specs = tools.split(",") if tools else [] if "all" in tools: configuration = Configuration.all_tools() else: validated_tools = [] for tool_spec in tool_specs: validated_tools.append(validate_tool_spec(tool_spec)) for product, action_str in validated_tools: scope = Scope(product, Actions(**{action_str: True})) configuration.add_scope(scope) return configuration def validate_tool_spec(tool_spec: str) -> tuple[Product, str]: try: product_str, action = tool_spec.split(".") except ValueError: raise ValueError(f"Tool spec '{tool_spec}' must be in the format 'product.action'") try: product = Product(product_str) except ValueError: raise ValueError(f"Invalid product: '{product_str}'. Valid products are: {[p.value for p in Product]}") valid_actions = Actions.__annotations__.keys() if action not in valid_actions: raise ValueError(f"Invalid action: '{action}'. Valid actions are: {list(valid_actions)}") return product, action ``` -------------------------------------------------------------------------------- /extend_ai_toolkit/examples/crewai-agent.py: -------------------------------------------------------------------------------- ```python """ Example implementation of an AI agent using the CrewAI framework. 
""" import os from pathlib import Path from dotenv import load_dotenv # Load environment variables from .env file, overriding any existing variables env_path = Path(__file__).parent.parent.parent / '.env' load_dotenv(env_path, override=True) import asyncio from extend_ai_toolkit.crewai.toolkit import ExtendCrewAIToolkit from extend_ai_toolkit.shared import Configuration, Scope, Product, Actions def validate_env_vars() -> tuple[str, str, str]: """Validate required environment variables. Returns: Tuple of (api_key, api_secret) Raises: ValueError: If any required environment variables are missing """ api_key = os.environ.get("EXTEND_API_KEY") api_secret = os.environ.get("EXTEND_API_SECRET") anthropic_key = os.environ.get("ANTHROPIC_API_KEY") if not all([api_key, api_secret, anthropic_key]): missing = [] if not api_key: missing.append("EXTEND_API_KEY") if not api_secret: missing.append("EXTEND_API_SECRET") if not anthropic_key: missing.append("ANTHROPIC_API_KEY") raise ValueError(f"Missing required environment variables: {', '.join(missing)}") return api_key, api_secret async def main(): try: # Validate environment variables api_key, api_secret = validate_env_vars() # Initialize the CrewAI toolkit toolkit = ExtendCrewAIToolkit.default_instance( api_key=api_key, api_secret=api_secret, configuration=Configuration( scope=[ Scope(Product.VIRTUAL_CARDS, actions=Actions(read=True)), Scope(Product.CREDIT_CARDS, actions=Actions(read=True)), Scope(Product.TRANSACTIONS, actions=Actions(read=True)), ] ) ) # Configure the LLM toolkit.configure_llm( model="claude-3-opus-20240229", api_key=os.environ.get("ANTHROPIC_API_KEY") ) # Create the Extend agent extend_agent = toolkit.create_agent( role="Extend Integration Expert", goal="Help users manage virtual cards, view credit cards, and check transactions efficiently", backstory="You are an expert at integrating with Extend, with deep knowledge of virtual cards, credit cards, and transaction management.", verbose=True ) # Create a task 
for handling user queries query_task = toolkit.create_task( description="Process and respond to user queries about Extend services", agent=extend_agent, expected_output="A clear and helpful response addressing the user's query", async_execution=True ) # Create a crew with the agent and task crew = toolkit.create_crew( agents=[extend_agent], tasks=[query_task], verbose=True ) # Example interaction with the agent print("Welcome to the Extend CrewAI Agent! Type 'quit' to exit.") while True: try: user_input = input("\nYour question: ").strip() if user_input.lower() == 'quit': break # Update task description with user input query_task.description = f"Process and respond to this user query: {user_input}" # Run the crew result = crew.kickoff() print("Agent response:", result) except Exception as e: print(f"Error processing query: {str(e)}") print("Please try again or type 'quit' to exit.") except Exception as e: print(f"Error initializing agent: {str(e)}") return if __name__ == "__main__": asyncio.run(main()) ``` -------------------------------------------------------------------------------- /extend_ai_toolkit/tests/test_langchain_toolkit.py: -------------------------------------------------------------------------------- ```python import inspect import json from unittest.mock import patch, Mock, AsyncMock import pytest from pydantic import BaseModel from extend_ai_toolkit.langchain.toolkit import ExtendLangChainToolkit from extend_ai_toolkit.shared import Configuration, ExtendAPITools, Tool, ExtendAPI # Define schema classes needed for testing class VirtualCardsSchema(BaseModel): page: int = 0 per_page: int = 10 class VirtualCardDetailSchema(BaseModel): virtual_card_id: str = "test_id" @pytest.fixture def mock_extend_api(): """Fixture that provides a mocked ExtendAPI instance""" with patch('extend_ai_toolkit.shared.agent_toolkit.ExtendAPI') as mock_api_class: mock_api_instance = Mock(spec=ExtendAPI) mock_api_instance.run = AsyncMock() 
mock_api_class.default_instance.return_value = mock_api_instance yield mock_api_class, mock_api_instance @pytest.fixture def mock_configuration(): """Fixture that provides a mocked Configuration instance with controlled tool permissions""" mock_config = Mock(spec=Configuration) # Create a list of allowed tools for testing allowed_tools = [ Tool( name="Get Virtual Cards", method=ExtendAPITools.GET_VIRTUAL_CARDS, description="Get all virtual cards", args_schema=VirtualCardsSchema, required_scope=[] ), Tool( name="Get Virtual Card Details", method=ExtendAPITools.GET_VIRTUAL_CARD_DETAIL, description="Get details of a virtual card", args_schema=VirtualCardDetailSchema, required_scope=[] ) ] # Configure the mock to return our controlled list of tools mock_config.allowed_tools.return_value = allowed_tools return mock_config @pytest.fixture def toolkit(mock_extend_api, mock_configuration): """Fixture that creates an ExtendLangChainToolkit instance with mocks""" mock_api_class, mock_api_instance = mock_extend_api toolkit = ExtendLangChainToolkit( extend_api=mock_api_instance, configuration=mock_configuration ) return toolkit def test_get_tools_returns_correct_tools(toolkit, mock_configuration): """Test that get_tools returns the correct set of tools""" tools = toolkit.get_tools() # We configured mock_configuration to return 2 tools assert len(tools) == 2 # Verify tool details assert tools[0].name == "get_virtual_cards" assert tools[0].description == "Get all virtual cards" assert tools[1].name == "get_virtual_card_detail" assert tools[1].description == "Get details of a virtual card" @pytest.mark.asyncio async def test_tool_execution_forwards_to_api(toolkit, mock_extend_api): """Test that tool execution correctly forwards requests to the API""" # Get the first tool tool = toolkit.get_tools()[0] # Set up a return value for the API call _, mock_api_instance = mock_extend_api mock_response = {"status": "success", "data": [{"id": "123"}]} mock_api_instance.run.return_value = 
mock_response # Call the tool result = await tool._arun(page=0, per_page=10) # Verify API was called correctly mock_api_instance.run.assert_called_once_with( ExtendAPITools.GET_VIRTUAL_CARDS.value, page=0, per_page=10 ) # Verify the result matches the mock response assert result == mock_response def test_tool_sync_execution_raises_error(toolkit): """Test that synchronous tool execution raises NotImplementedError""" # Get the first tool tool = toolkit.get_tools()[0] # Attempt synchronous execution with pytest.raises(NotImplementedError, match="ExtendTool only supports async operations"): tool._run(page=0, per_page=10) def test_tool_schema_matches_expected(toolkit): """Test that the tool has the correct schema""" # Get the first tool tool = toolkit.get_tools()[0] # Verify the tool has the correct schema class assert tool.args_schema == VirtualCardsSchema ``` -------------------------------------------------------------------------------- /extend_ai_toolkit/tests/test_configuration.py: -------------------------------------------------------------------------------- ```python # test_configuration.py from dataclasses import dataclass from typing import List import pytest from extend_ai_toolkit.shared import ( Configuration, Product, Actions, tools, ) # Dummy implementations for testing: class ToolScope: """ Represents a scope requirement for a tool. """ def __init__(self, product_type: Product, actions: Actions): self.type = product_type self.actions = actions @dataclass class Tool: """ A dummy Tool for testing. It has a name and a list of required scopes. """ name: str required_scope: List[ToolScope] # Test that the classmethod all_tools creates a configuration with the expected defaults. 
def test_all_tools_configuration():
    config = Configuration.all_tools()
    assert config.scope is not None

    # Expected product -> granted actions, derived from every registered tool.
    expected_scopes = {}
    for registered_tool in tools:
        for required in registered_tool.required_scope:
            granted = {name for name, enabled in required.actions.items() if enabled}
            expected_scopes.setdefault(required.type, set()).update(granted)

    # Exactly one configured scope per product that any tool needs.
    assert len(config.scope) == len(expected_scopes)

    # Every configured product must be expected and must grant each expected action.
    for configured in config.scope:
        wanted = expected_scopes.get(configured.type)
        assert wanted is not None, f"Unexpected product in config: {configured.type}"
        for action in wanted:
            assert configured.actions.get(action) is True, f"{configured.type}.{action} should be True"


# Test is_tool_in_scope returns True when tool requirements match configuration.
def test_is_tool_in_scope_success():
    # credit_cards.read is part of the full configuration, so a tool needing
    # only that scope has to be accepted.
    config = Configuration.all_tools()
    tool = Tool(
        name="Tool1",
        required_scope=[
            ToolScope(product_type=Product.CREDIT_CARDS, actions=Actions(read=True)),
        ],
    )
    assert config.is_tool_in_scope(tool) is True


# Test is_tool_in_scope returns False when a required scope is missing.
def test_is_tool_in_scope_failure_missing_scope_action():
    # The full configuration grants transactions read/update but never
    # create, so a tool that needs transactions.create must be rejected.
    config = Configuration.all_tools()
    tool = Tool(
        name="Tool2",
        required_scope=[
            ToolScope(product_type=Product.TRANSACTIONS, actions=Actions(create=True)),
        ],
    )
    assert config.is_tool_in_scope(tool) is False


# Test allowed_tools returns only the tools that meet the scope requirements.
def test_allowed_tools(): config = Configuration.all_tools() # Tool1 meets its requirement (credit_cards with read True) tool1 = Tool( name="Tool1", required_scope=[ ToolScope( product_type=Product.CREDIT_CARDS, actions=Actions(read=True) ) ] ) # Tool2 does not meet its requirement (transactions with create True, but not allowed) tool2 = Tool( name="Tool2", required_scope=[ ToolScope( product_type=Product.TRANSACTIONS, actions=Actions(create=True) ) ] ) # Tool3 for expense categories; requires read access. tool3 = Tool( name="Tool3", required_scope=[ ToolScope( product_type=Product.EXPENSE_CATEGORIES, actions=Actions(read=True) ) ] ) # Get allowed tools from configuration. allowed = config.allowed_tools([tool1, tool2, tool3]) allowed_names = [tool.name for tool in allowed] # Tool1 and Tool3 should be allowed; Tool2 should not. assert "Tool1" in allowed_names assert "Tool3" in allowed_names assert "Tool2" not in allowed_names if __name__ == "__main__": pytest.main() ``` -------------------------------------------------------------------------------- /extend_ai_toolkit/tests/test_options.py: -------------------------------------------------------------------------------- ```python # test_options.py import os import pytest from extend_ai_toolkit.modelcontextprotocol import Options, validate_options @pytest.fixture(autouse=True) def clear_environment_variables(): """Clear relevant environment variables before each test""" for var in ["EXTEND_API_KEY", "EXTEND_API_SECRET"]: if var in os.environ: del os.environ[var] yield def test_initialization(): """Test basic initialization with valid arguments""" options = Options( tools="tool1,tool2", api_key="apik_12345", api_secret="secret123" ) assert options.tools == "tool1,tool2" assert options.api_key == "apik_12345" assert options.api_secret == "secret123" def test_missing_api_key(): """Test validation when api_key is missing""" with pytest.raises(ValueError, match="Extend API key not provided"): Options( tools="tool1,tool2", 
api_key=None, api_secret="secret123" ) def test_invalid_api_key_format(): """Test validation when api_key has invalid format""" with pytest.raises(ValueError, match='Extend API key must start with "apik_"'): Options( tools="tool1,tool2", api_key="invalid_key", api_secret="secret123" ) def test_missing_api_secret(): """Test validation when api_secret is missing""" with pytest.raises(ValueError, match="Extend API key not provided"): Options( tools="tool1,tool2", api_key="apik_12345", api_secret=None ) def test_missing_tools(): """Test validation when tools is missing""" with pytest.raises(ValueError, match="The --tools argument must be provided"): Options( tools=None, api_key="apik_12345", api_secret="secret123" ) def test_from_args_with_env_vars(monkeypatch): """Test from_args using environment variables""" monkeypatch.setenv("EXTEND_API_KEY", "apik_env") monkeypatch.setenv("EXTEND_API_SECRET", "env_secret") options = Options.from_args(["--tools=tool1,tool2"], ["tool1", "tool2"]) assert options.tools == "tool1,tool2" assert options.api_key == "apik_env" assert options.api_secret == "env_secret" def test_from_args_with_cli_args(): """Test from_args using command line arguments""" options = Options.from_args([ "--api-key=apik_cli", "--api-secret=cli_secret", "--tools=tool1,tool2" ], ["tool1", "tool2"]) assert options.tools == "tool1,tool2" assert options.api_key == "apik_cli" assert options.api_secret == "cli_secret" def test_from_args_invalid_tool(): """Test from_args with invalid tool""" with pytest.raises(ValueError, match="Invalid tool: invalid_tool"): Options.from_args(["--tools=invalid_tool"], ["tool1", "tool2"]) def test_from_args_all_tools(): """Test from_args with 'all' as tool""" options = Options.from_args([ "--api-key=apik_12345", "--api-secret=secret123", "--tools=all" ], ["tool1", "tool2"]) assert options.tools == "all" def test_from_args_invalid_format(): """Test from_args with invalid argument format""" with pytest.raises(ValueError, match="is not in 
--key=value format"): Options.from_args(["--api-key"], ["tool1", "tool2"]) def test_from_args_invalid_argument(): """Test from_args with invalid argument name""" with pytest.raises(ValueError, match="Invalid argument: invalid"): Options.from_args(["--invalid=value"], ["tool1", "tool2"]) def test_validate_options_decorator(): """Test the validate_options decorator""" @validate_options class TestClass: def __init__(self, tools, api_key, api_secret): self.tools = tools self.api_key = api_key self.api_secret = api_secret # Should run without errors instance = TestClass( tools="tool1,tool2", api_key="apik_12345", api_secret="secret123" ) # Check that validation errors are raised with pytest.raises(ValueError): TestClass( tools=None, api_key="apik_12345", api_secret="secret123" ) if __name__ == "__main__": pytest.main() ``` -------------------------------------------------------------------------------- /extend_ai_toolkit/tests/test_openai_toolkit.py: -------------------------------------------------------------------------------- ```python import inspect import json from unittest.mock import patch, Mock, AsyncMock import pytest from pydantic import BaseModel from agents import FunctionTool from extend_ai_toolkit.openai.toolkit import ExtendOpenAIToolkit from extend_ai_toolkit.shared import Configuration, ExtendAPITools, Tool, ExtendAPI # Define schema classes needed for testing class VirtualCardsSchema(BaseModel): page: int = 0 per_page: int = 10 class VirtualCardDetailSchema(BaseModel): virtual_card_id: str = "test_id" @pytest.fixture def mock_extend_api(): """Fixture that provides a mocked ExtendAPI instance""" with patch('extend_ai_toolkit.shared.agent_toolkit.ExtendAPI') as mock_api_class: mock_api_instance = Mock(spec=ExtendAPI) mock_api_instance.run = AsyncMock() mock_api_class.default_instance.return_value = mock_api_instance yield mock_api_class, mock_api_instance @pytest.fixture def mock_configuration(): """Fixture that provides a mocked Configuration instance 
with controlled tool permissions""" mock_config = Mock(spec=Configuration) # Create a list of allowed tools for testing allowed_tools = [ Tool( name="Get Virtual Cards", method=ExtendAPITools.GET_VIRTUAL_CARDS, description="Get all virtual cards", args_schema=VirtualCardsSchema, required_scope=[] ), Tool( name="Get Virtual Card Details", method=ExtendAPITools.GET_VIRTUAL_CARD_DETAIL, description="Get details of a virtual card", args_schema=VirtualCardDetailSchema, required_scope=[] ) ] # Configure the mock to return our controlled list of tools mock_config.allowed_tools.return_value = allowed_tools return mock_config @pytest.fixture def toolkit(mock_extend_api, mock_configuration): """Fixture that creates an ExtendOpenAIToolkit instance with mocks""" _, mock_api_instance = mock_extend_api toolkit = ExtendOpenAIToolkit( extend_api=mock_api_instance, configuration=mock_configuration ) return toolkit def test_get_tools_returns_correct_tools(toolkit, mock_configuration): """Test that get_tools returns the correct set of tools""" tools = toolkit.get_tools() # We configured mock_configuration to return 2 tools assert len(tools) == 2 # Verify tool details assert tools[0].name == ExtendAPITools.GET_VIRTUAL_CARDS.value assert tools[0].description == "Get all virtual cards" assert tools[1].name == ExtendAPITools.GET_VIRTUAL_CARD_DETAIL.value assert tools[1].description == "Get details of a virtual card" @pytest.mark.asyncio async def test_tool_execution_forwards_to_api(toolkit, mock_extend_api): """Test that tool execution correctly forwards requests to the API""" # Get the first tool tool = toolkit.get_tools()[0] # Set up a return value for the API call _, mock_api_instance = mock_extend_api mock_response = {"status": "success", "data": [{"id": "123"}]} mock_api_instance.run.return_value = mock_response # Call the tool input_str = json.dumps({"page": 0, "per_page": 10}) result = await tool.on_invoke_tool(None, input_str) # Verify API was called correctly 
mock_api_instance.run.assert_called_once_with( ExtendAPITools.GET_VIRTUAL_CARDS.value, page=0, per_page=10 ) # Verify the result matches the mock response assert result == mock_response def test_tool_schema_matches_expected(toolkit): """Test that the tool has the correct schema""" # Get the first tool tool = toolkit.get_tools()[0] # Get the expected schema expected_schema = VirtualCardsSchema.model_json_schema() expected_schema["additionalProperties"] = False expected_schema["type"] = "object" if "description" in expected_schema: del expected_schema["description"] if "title" in expected_schema: del expected_schema["title"] if "properties" in expected_schema: for prop in expected_schema["properties"].values(): if "title" in prop: del prop["title"] if "default" in prop: del prop["default"] # Verify the tool has the correct schema assert tool.params_json_schema == expected_schema ``` -------------------------------------------------------------------------------- /extend_ai_toolkit/modelcontextprotocol/server.py: -------------------------------------------------------------------------------- ```python import inspect import logging from typing import Any from mcp.server import FastMCP from mcp.types import AnyFunction from extend_ai_toolkit.shared import Configuration from extend_ai_toolkit.shared import ExtendAPI from extend_ai_toolkit.shared import ExtendAPITools from extend_ai_toolkit.shared import functions from extend_ai_toolkit.shared import tools, Tool from ..__version__ import __version__ as _version logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) class ExtendMCPServer(FastMCP): def __init__(self, extend_api: ExtendAPI, configuration: Configuration): super().__init__( name="Extend MCP Server", version=_version ) self._extend = extend_api for tool in configuration.allowed_tools(tools): fn: Any = None match tool.method.value: case ExtendAPITools.GET_VIRTUAL_CARDS.value: fn = functions.get_virtual_cards case 
ExtendAPITools.GET_VIRTUAL_CARD_DETAIL.value: fn = functions.get_virtual_card_detail case ExtendAPITools.CANCEL_VIRTUAL_CARD.value: fn = functions.cancel_virtual_card case ExtendAPITools.CLOSE_VIRTUAL_CARD.value: fn = functions.close_virtual_card case ExtendAPITools.GET_TRANSACTIONS.value: fn = functions.get_transactions case ExtendAPITools.GET_TRANSACTION_DETAIL.value: fn = functions.get_transaction_detail case ExtendAPITools.GET_CREDIT_CARDS.value: fn = functions.get_credit_cards case ExtendAPITools.GET_CREDIT_CARD_DETAIL.value: fn = functions.get_credit_card_detail case ExtendAPITools.GET_EXPENSE_CATEGORIES.value: fn = functions.get_expense_categories case ExtendAPITools.GET_EXPENSE_CATEGORY.value: fn = functions.get_expense_category case ExtendAPITools.GET_EXPENSE_CATEGORY_LABELS.value: fn = functions.get_expense_category_labels case ExtendAPITools.CREATE_EXPENSE_CATEGORY.value: fn = functions.create_expense_category case ExtendAPITools.CREATE_EXPENSE_CATEGORY_LABEL.value: fn = functions.create_expense_category_label case ExtendAPITools.UPDATE_EXPENSE_CATEGORY.value: fn = functions.update_expense_category case ExtendAPITools.UPDATE_EXPENSE_CATEGORY_LABEL.value: fn = functions.update_expense_category_label case ExtendAPITools.PROPOSE_EXPENSE_CATEGORY_LABEL.value: fn = functions.propose_transaction_expense_data case ExtendAPITools.CONFIRM_EXPENSE_CATEGORY_LABEL.value: fn = functions.confirm_transaction_expense_data case ExtendAPITools.UPDATE_TRANSACTION_EXPENSE_DATA.value: fn = functions.update_transaction_expense_data case ExtendAPITools.CREATE_RECEIPT_ATTACHMENT.value: fn = functions.create_receipt_attachment case ExtendAPITools.AUTOMATCH_RECEIPTS.value: fn = functions.automatch_receipts case ExtendAPITools.GET_AUTOMATCH_STATUS.value: fn = functions.get_automatch_status case ExtendAPITools.SEND_RECEIPT_REMINDER.value: fn = functions.send_receipt_reminder case _: raise ValueError(f"Invalid tool {tool}") self.add_tool( self._handle_tool_request(tool, fn), 
tool.name, tool.description ) @classmethod def default_instance(cls, api_key: str, api_secret: str, configuration: Configuration): return cls(extend_api=ExtendAPI.default_instance(api_key, api_secret), configuration=configuration) def _handle_tool_request(self, tool: Tool, fn: AnyFunction): async def resource_handler(*args, **kwargs): result = await self._extend.run(tool.method.value, *args, **kwargs) return { "content": [ { "type": "text", "text": str(result) } ] } orig_sig = inspect.signature(fn) new_params = list(orig_sig.parameters.values())[1:] resource_handler.__signature__ = inspect.Signature(new_params) return resource_handler ``` -------------------------------------------------------------------------------- /extend_ai_toolkit/shared/api.py: -------------------------------------------------------------------------------- ```python from dotenv import load_dotenv from .enums import ExtendAPITools from .functions import * from .helpers import * logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) load_dotenv() class ExtendAPI: """Wrapper around Extend API""" def __init__( self, extend: ExtendClient, ): self.extend = extend @classmethod def default_instance(cls, api_key: str, api_secret: str) -> "ExtendAPI": return cls( extend=ExtendClient( api_key=api_key, api_secret=api_secret ) ) async def run(self, tool: str, *args, **kwargs) -> str: match ExtendAPITools(tool).value: case ExtendAPITools.GET_VIRTUAL_CARDS.value: output = await get_virtual_cards(self.extend, *args, **kwargs) return format_virtual_cards_list(output) case ExtendAPITools.GET_VIRTUAL_CARD_DETAIL.value: output = await get_virtual_card_detail(self.extend, *args, **kwargs) return format_virtual_card_details(output) case ExtendAPITools.CANCEL_VIRTUAL_CARD.value: output = await cancel_virtual_card(self.extend, *args, **kwargs) return format_canceled_virtual_card(output) case ExtendAPITools.CLOSE_VIRTUAL_CARD.value: output = await close_virtual_card(self.extend, *args, **kwargs) return 
format_closed_virtual_card(output) case ExtendAPITools.GET_TRANSACTIONS.value: output = await get_transactions(self.extend, *args, **kwargs) return format_transactions_list(output) case ExtendAPITools.GET_TRANSACTION_DETAIL.value: output = await get_transaction_detail(self.extend, *args, **kwargs) return format_transaction_details(output) case ExtendAPITools.GET_CREDIT_CARDS.value: output = await get_credit_cards(self.extend, *args, **kwargs) return format_credit_cards_list(output) case ExtendAPITools.GET_CREDIT_CARD_DETAIL.value: output = await get_credit_card_detail(self.extend, *args, **kwargs) return format_credit_card_detail(output) case ExtendAPITools.GET_EXPENSE_CATEGORIES.value: output = await get_expense_categories(self.extend, *args, **kwargs) return json.dumps(output) case ExtendAPITools.GET_EXPENSE_CATEGORY.value: output = await get_expense_category(self.extend, *args, **kwargs) return json.dumps(output) case ExtendAPITools.GET_EXPENSE_CATEGORY_LABELS.value: output = await get_expense_category_labels(self.extend, *args, **kwargs) return json.dumps(output) case ExtendAPITools.CREATE_EXPENSE_CATEGORY.value: output = await create_expense_category(self.extend, *args, **kwargs) return json.dumps(output) case ExtendAPITools.CREATE_EXPENSE_CATEGORY_LABEL.value: output = await create_expense_category_label(self.extend, *args, **kwargs) return json.dumps(output) case ExtendAPITools.UPDATE_EXPENSE_CATEGORY.value: output = await update_expense_category(self.extend, *args, **kwargs) return json.dumps(output) case ExtendAPITools.UPDATE_EXPENSE_CATEGORY_LABEL.value: output = await update_expense_category_label(self.extend, *args, **kwargs) return json.dumps(output) case ExtendAPITools.UPDATE_TRANSACTION_EXPENSE_DATA.value: output = await update_transaction_expense_data(self.extend, *args, **kwargs) return json.dumps(output) case ExtendAPITools.PROPOSE_EXPENSE_CATEGORY_LABEL.value: output = await propose_transaction_expense_data(self.extend, *args, **kwargs) return 
json.dumps(output) case ExtendAPITools.CONFIRM_EXPENSE_CATEGORY_LABEL.value: output = await confirm_transaction_expense_data(self.extend, *args, **kwargs) return json.dumps(output) case ExtendAPITools.CREATE_RECEIPT_ATTACHMENT.value: output = await create_receipt_attachment(self.extend, *args, **kwargs) return json.dumps(output) case ExtendAPITools.AUTOMATCH_RECEIPTS.value: output = await automatch_receipts(self.extend, *args, **kwargs) return json.dumps(output) case ExtendAPITools.GET_AUTOMATCH_STATUS.value: output = await get_automatch_status(self.extend, *args, **kwargs) return json.dumps(output) case ExtendAPITools.SEND_RECEIPT_REMINDER.value: output = await send_receipt_reminder(self.extend, *args, **kwargs) return json.dumps(output) case _: raise ValueError(f"Invalid tool {tool}") ``` -------------------------------------------------------------------------------- /extend_ai_toolkit/modelcontextprotocol/client/anthropic_chat_client.py: -------------------------------------------------------------------------------- ```python import json import os from typing import List, Dict, Any, Tuple, Optional from anthropic import AsyncAnthropic from .chat_client import ChatClient class AnthropicChatClient(ChatClient): """Implementation of ChatClient for Anthropic API""" def __init__( self, model_name="claude-3-7-sonnet-20250219", system_prompt="You are a helpful assistant."): self.model_name = model_name self.client = AsyncAnthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")) self.system_prompt = system_prompt async def generate_completion( self, messages: List[Dict[str, Any]], functions: List[Dict[str, Any]], max_tokens: int ) -> Tuple[Optional[str], Optional[Dict]]: # Convert OpenAI-style messages to Anthropic format anthropic_messages = self._convert_messages(messages) # Convert OpenAI-style functions to Anthropic tools tools = self._convert_functions_to_tools(functions) response = await self.client.messages.create( model=self.model_name, max_tokens=max_tokens, 
messages=anthropic_messages, tools=tools, system=self.system_prompt ) # Process all content blocks and prioritize tool_use if present text_content = [] tool_use_info = None if response.content and len(response.content) > 0: for content_block in response.content: if content_block.type == "tool_use": # Get tool use data name = getattr(content_block, "name", None) input_data = getattr(content_block, "input", {}) # Convert input data to JSON string try: input_json = json.dumps(input_data) except TypeError: # Fallback for non-JSON-serializable objects if hasattr(input_data, "__dict__"): input_dict = input_data.__dict__ else: input_dict = {"data": str(input_data)} input_json = json.dumps(input_dict) tool_use_info = { "name": name if name else "unknown_tool", "arguments": input_json } elif content_block.type == "text": text_content.append(str(content_block.text)) # Combine all text content combined_text = "\n".join(text_content) if text_content else None # Prioritize tool_use if present if tool_use_info: return combined_text, tool_use_info else: return combined_text, None async def generate_with_tool_result( self, messages: List[Dict[str, Any]], max_tokens: int) -> str: # Convert OpenAI-style messages to Anthropic format anthropic_messages = self._convert_messages(messages) response = await self.client.messages.create( model=self.model_name, max_tokens=max_tokens, messages=anthropic_messages, system=self.system_prompt ) if response.content and len(response.content) > 0: text_blocks = [] for content_block in response.content: if content_block.type == "text": text_blocks.append(str(content_block.text)) return "\n".join(text_blocks) if text_blocks else "" return "" def _convert_messages(self, openai_messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Convert OpenAI message format to Anthropic format""" anthropic_messages = [] tool_use_ids = {} # Map to store tool use IDs tool_use_count = 0 for i, msg in enumerate(openai_messages): if msg["role"] == "user": 
anthropic_messages.append({ "role": "user", "content": msg["content"] }) elif msg["role"] == "assistant": # Handle potential function calls in assistant messages if msg.get("function_call"): # Create a unique ID for this tool use tool_use_id = f"tool_{tool_use_count}" tool_use_count += 1 # Store the mapping for future tool results tool_use_ids[len(anthropic_messages)] = tool_use_id try: input_data = json.loads(msg["function_call"]["arguments"]) except json.JSONDecodeError: input_data = {} anthropic_messages.append({ "role": "assistant", "content": [{ "type": "tool_use", "id": tool_use_id, "name": msg["function_call"]["name"], "input": input_data }] }) else: anthropic_messages.append({ "role": "assistant", "content": msg["content"] }) elif msg["role"] == "function": # Convert function messages to tool responses # Find the corresponding tool use ID if available # Default to a generated ID if not found tool_use_id = tool_use_ids.get(len(anthropic_messages) - 1, f"tool_{tool_use_count}") tool_use_count += 1 anthropic_messages.append({ "role": "user", "content": [{ "type": "tool_result", "tool_use_id": tool_use_id, # Using tool_use_id as required by Anthropic API "content": msg["content"] }] }) return anthropic_messages def _convert_functions_to_tools(self, functions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Convert OpenAI function definitions to Anthropic tool format""" tools = [] for func in functions: tools.append({ "name": func["name"], "description": func.get("description", ""), "input_schema": func["parameters"] }) return tools ``` -------------------------------------------------------------------------------- /extend_ai_toolkit/tests/test_mcp_server.py: -------------------------------------------------------------------------------- ```python import inspect from unittest.mock import patch, Mock, AsyncMock import pytest from mcp.server import FastMCP from pydantic import BaseModel from extend_ai_toolkit.modelcontextprotocol import ExtendMCPServer from 
extend_ai_toolkit.shared import Configuration, ExtendAPITools, Tool # Define schema classes needed for testing class InvalidToolSchema(BaseModel): dummy: str = "test" class VirtualCardsSchema(BaseModel): page: int = 0 per_page: int = 10 class VirtualCardDetailSchema(BaseModel): virtual_card_id: str = "test_id" @pytest.fixture def mock_extend_api(): """Fixture that provides a mocked ExtendAPI instance""" from extend_ai_toolkit.modelcontextprotocol import server original_api = server.ExtendAPI try: # Replace with mock mock_api_class = Mock() mock_api_instance = Mock() mock_api_instance.run = AsyncMock() mock_api_class.default_instance.return_value = mock_api_instance server.ExtendAPI = mock_api_class yield mock_api_class finally: # Restore original server.ExtendAPI = original_api @pytest.fixture def mock_configuration(): """Fixture that provides a mocked Configuration instance with controlled tool permissions""" mock_config = Mock(spec=Configuration) # Create a list of allowed tools for testing allowed_tools = [ Tool( name="Get Virtual Cards", method=ExtendAPITools.GET_VIRTUAL_CARDS, description="Get all virtual cards", args_schema=VirtualCardsSchema, required_scope=[] ), Tool( name="Get Virtual Card Details", method=ExtendAPITools.GET_VIRTUAL_CARD_DETAIL, description="Get details of a virtual card", args_schema=VirtualCardDetailSchema, required_scope=[] ) ] # Configure the mock to return our controlled list of tools mock_config.allowed_tools.return_value = allowed_tools return mock_config @pytest.fixture def mock_fastmcp(): """Fixture that patches the FastMCP parent class""" with patch.object(FastMCP, "__init__", return_value=None) as mock_init: with patch.object(FastMCP, "add_tool") as mock_add_tool: yield { "init": mock_init, "add_tool": mock_add_tool } @pytest.fixture def server(mock_extend_api, mock_configuration, mock_fastmcp): """Fixture that creates an ExtendMCPServer instance with mocks""" server = ExtendMCPServer.default_instance( api_key="test_api_key", 
api_secret="test_api_secret", configuration=mock_configuration ) # Attach the mocks for reference in tests server._mock_api = mock_extend_api server._mock_fastmcp = mock_fastmcp return server def test_init_calls_parent_constructor(mock_fastmcp): """Test that parent constructor is called with correct parameters""" # Create the server directly since we're testing initialization mock_config = Mock(spec=Configuration) # Configure allowed_tools to return an empty list (iterable) mock_config.allowed_tools.return_value = [] ExtendMCPServer.default_instance( api_key="test_api_key", api_secret="test_api_secret", configuration=mock_config ) # Verify the parent constructor was called with correct arguments mock_fastmcp["init"].assert_called_once_with( name="Extend MCP Server", version="1.1.0", ) def test_init_registers_allowed_tools(server, mock_configuration, mock_fastmcp): """Test that allowed tools are registered correctly""" # We configured mock_configuration to return 2 tools assert mock_fastmcp["add_tool"].call_count == 2 # Verify tool details for the first call args, kwargs = mock_fastmcp["add_tool"].call_args_list[0] assert args[1] == "get_virtual_cards" assert args[2] == "Get all virtual cards" # Verify tool details for the second call args, kwargs = mock_fastmcp["add_tool"].call_args_list[1] assert args[1] == "get_virtual_card_detail" assert args[2] == "Get details of a virtual card" @pytest.mark.asyncio async def test_handle_tool_request_forwards_to_api(server, mock_extend_api): """Test that the handler function correctly forwards requests to the API""" # Get the first mock tool mock_tool = server._mock_fastmcp["add_tool"].call_args_list[0][0][0] # Set up a return value for the API call server._mock_api.default_instance.return_value.run.return_value = {"status": "success", "data": [{"id": "123"}]} # Call the handler result = await mock_tool(page=0, per_page=10) # Verify API was called correctly 
server._mock_api.default_instance.return_value.run.assert_called_once_with(
        ExtendAPITools.GET_VIRTUAL_CARDS.value,
        page=0,
        per_page=10
    )

    # Verify the result is formatted correctly
    assert result == {
        "content": [
            {
                "type": "text",
                "text": str({"status": "success", "data": [{"id": "123"}]})
            }
        ]
    }


def test_handler_signature_matches_function(server):
    """Test that the generated handler functions have the correct signature"""
    # Get the first handler function
    # (first positional argument of the first recorded FastMCP `add_tool` call)
    mock_handler = server._mock_fastmcp["add_tool"].call_args_list[0][0][0]

    # Inspect its signature
    sig = inspect.signature(mock_handler)

    # For get_virtual_cards, we expect parameters like page, per_page, etc.
    # (minus the first parameter which is usually 'self' or 'api')
    # Check that these parameters exist
    assert "page" in sig.parameters
    assert "per_page" in sig.parameters


def test_match_statement_default_case():
    """Test that the default case in a match statement raises ValueError"""
    # NOTE(review): this exercises a local stand-in for the dispatch, not
    # ExtendMCPServer itself, so it cannot catch regressions in the real match.
    # Create a mock tool
    mock_tool = Mock()
    mock_tool.name = "Test Tool"
    mock_tool.method.value = "non_existent_method"

    # Define a function that mimics the match statement in ExtendMCPServer.__init__
    def match_function(tool):
        match tool.method.value:
            case ExtendAPITools.GET_VIRTUAL_CARDS.value:
                return "get_virtual_cards"
            case ExtendAPITools.GET_VIRTUAL_CARD_DETAIL.value:
                return "get_virtual_card_detail"
            # Add other cases as needed
            case _:
                raise ValueError(f"Invalid tool {tool}")

    # Test that the default case raises ValueError
    with pytest.raises(ValueError):
        match_function(mock_tool)
```
--------------------------------------------------------------------------------
/extend_ai_toolkit/tests/test_crewai_toolkit.py:
--------------------------------------------------------------------------------
```python
import inspect
import json
import re
from unittest.mock import patch, Mock, AsyncMock

import pytest
from pydantic import BaseModel
from crewai import Agent, Task, Crew, LLM
from crewai.tools import BaseTool

from extend_ai_toolkit.crewai.toolkit import ExtendCrewAIToolkit
from extend_ai_toolkit.shared import Configuration, ExtendAPITools, Tool, ExtendAPI


# Define schema classes needed for testing
class VirtualCardsSchema(BaseModel):
    page: int = 0
    per_page: int = 10


class VirtualCardDetailSchema(BaseModel):
    virtual_card_id: str = "test_id"


@pytest.fixture
def mock_extend_api():
    """Fixture that provides a mocked ExtendAPI instance"""
    # Patch ExtendAPI where agent_toolkit looks it up. `run` is an AsyncMock so
    # tools can await it. Yields (patched class, shared mock instance).
    with patch('extend_ai_toolkit.shared.agent_toolkit.ExtendAPI') as mock_api_class:
        mock_api_instance = Mock(spec=ExtendAPI)
        mock_api_instance.run = AsyncMock()
        mock_api_class.default_instance.return_value = mock_api_instance
        yield mock_api_class, mock_api_instance


@pytest.fixture
def mock_configuration():
    """Fixture that provides a mocked Configuration instance with controlled tool permissions"""
    mock_config = Mock(spec=Configuration)

    # Create a list of allowed tools for testing
    # NOTE(review): `Tool` exposes `name` as a read-only property derived from
    # `method`, so the `name=` arguments below are presumably ignored by
    # pydantic rather than used — confirm.
    allowed_tools = [
        Tool(
            name="Get Virtual Cards",
            method=ExtendAPITools.GET_VIRTUAL_CARDS,
            description="Get all virtual cards",
            args_schema=VirtualCardsSchema,
            required_scope=[]
        ),
        Tool(
            name="Get Virtual Card Details",
            method=ExtendAPITools.GET_VIRTUAL_CARD_DETAIL,
            description="Get details of a virtual card",
            args_schema=VirtualCardDetailSchema,
            required_scope=[]
        )
    ]

    # Configure the mock to return our controlled list of tools
    mock_config.allowed_tools.return_value = allowed_tools

    return mock_config


@pytest.fixture
def toolkit(mock_extend_api, mock_configuration):
    """Fixture that creates an ExtendCrewAIToolkit instance with mocks"""
    _, mock_api_instance = mock_extend_api

    toolkit = ExtendCrewAIToolkit(
        extend_api=mock_api_instance,
        configuration=mock_configuration
    )
    return toolkit


def test_get_tools_returns_correct_tools(toolkit, mock_configuration):
    """Test that get_tools returns the correct set of tools"""
    tools = toolkit.get_tools()

    # We configured mock_configuration to return 2 tools
    assert len(tools) == 2

    # Verify tool details
    assert tools[0].name == ExtendAPITools.GET_VIRTUAL_CARDS.value
    assert "Tool Name: get_virtual_cards" in tools[0].description
    assert "Tool Description: Get all virtual cards" in tools[0].description

    assert tools[1].name == ExtendAPITools.GET_VIRTUAL_CARD_DETAIL.value
    assert "Tool Name: get_virtual_card_detail" in tools[1].description
    assert "Tool Description: Get details of a virtual card" in tools[1].description


@pytest.mark.asyncio
async def test_tool_execution_forwards_to_api(toolkit, mock_extend_api):
    """Test that tool execution correctly forwards requests to the API"""
    # Get the first tool
    tool = toolkit.get_tools()[0]

    # Set up a return value for the API call
    _, mock_api_instance = mock_extend_api
    mock_response = {"status": "success", "data": [{"id": "123"}]}
    mock_api_instance.run.return_value = mock_response

    # Call the tool
    result = await tool._arun(page=0, per_page=10)

    # Verify API was called correctly
    mock_api_instance.run.assert_called_once_with(
        ExtendAPITools.GET_VIRTUAL_CARDS.value,
        page=0,
        per_page=10
    )

    # Verify the result matches the mock response
    assert result == mock_response


def test_tool_sync_execution_works(toolkit, mock_extend_api):
    """Test that synchronous tool execution works by creating an event loop"""
    # Get the first tool
    tool = toolkit.get_tools()[0]

    # Set up a return value for the API call
    _, mock_api_instance = mock_extend_api
    mock_response = {"status": "success", "data": [{"id": "123"}]}
    mock_api_instance.run.return_value = mock_response

    # Call the tool synchronously (no await here, although `run` is an AsyncMock)
    result = tool._run(page=0, per_page=10)

    # Verify API was called correctly
    mock_api_instance.run.assert_called_once_with(
        ExtendAPITools.GET_VIRTUAL_CARDS.value,
        page=0,
        per_page=10
    )

    # Verify the result matches the mock response
    assert result == mock_response


def test_tool_schema_matches_expected(toolkit):
    """Test that the tool has the correct schema"""
    # Get the first tool
    tool = toolkit.get_tools()[0]

    # Verify the tool has the correct schema class
    assert tool.args_schema == VirtualCardsSchema


def test_configure_llm(toolkit):
    """Test that LLM configuration works correctly"""
    toolkit.configure_llm(
        model="test-model",
        api_key="test-api-key",
        temperature=0.7
    )

    assert toolkit._llm is not None
    assert isinstance(toolkit._llm, LLM)
    assert toolkit._llm.model == "test-model"
    assert toolkit._llm.api_key == "test-api-key"
    assert toolkit._llm.temperature == 0.7


def test_create_agent_requires_llm(toolkit):
    """Test that creating an agent without configuring LLM raises an error"""
    with pytest.raises(ValueError, match=re.escape("No LLM configured. Call configure_llm() first.")):
        toolkit.create_agent(
            role="Test Role",
            goal="Test Goal",
            backstory="Test Backstory"
        )


def test_create_agent_with_llm(toolkit):
    """Test that agent creation works correctly with configured LLM"""
    toolkit.configure_llm(model="test-model", api_key="test-api-key")

    agent = toolkit.create_agent(
        role="Test Role",
        goal="Test Goal",
        backstory="Test Backstory",
        verbose=True
    )

    assert isinstance(agent, Agent)
    assert agent.role == "Test Role"
    assert agent.goal == "Test Goal"
    assert agent.backstory == "Test Backstory"
    assert agent.verbose is True
    assert len(agent.tools) == 2  # From our mock configuration


def test_create_task(toolkit):
    """Test that task creation works correctly"""
    toolkit.configure_llm(model="test-model", api_key="test-api-key")
    agent = toolkit.create_agent(
        role="Test Role",
        goal="Test Goal",
        backstory="Test Backstory"
    )

    task = toolkit.create_task(
        description="Test Description",
        agent=agent,
        expected_output="Test Output",
        async_execution=True
    )

    assert isinstance(task, Task)
    assert task.description == "Test Description"
    assert task.agent == agent
    assert task.expected_output == "Test Output"
    assert task.async_execution is True


def test_create_crew(toolkit):
    """Test that crew creation works correctly"""
    toolkit.configure_llm(model="test-model", api_key="test-api-key")
    agent = toolkit.create_agent(
        role="Test Role",
        goal="Test Goal",
        backstory="Test Backstory"
    )

    task = toolkit.create_task(
        description="Test Description",
        agent=agent,
expected_output="Test Output" ) crew = toolkit.create_crew( agents=[agent], tasks=[task], verbose=True ) assert isinstance(crew, Crew) assert len(crew.agents) == 1 assert len(crew.tasks) == 1 assert crew.verbose is True ``` -------------------------------------------------------------------------------- /extend_ai_toolkit/shared/helpers.py: -------------------------------------------------------------------------------- ```python import json from typing import Dict # Helper functions for formatting responses def add_line(label, value): """Return a formatted line only if value is not None or 'N/A'.""" if value is not None and value != "N/A": return f" {label}: {value}\n" return "" def format_virtual_cards_list(response: Dict) -> str: """Format the virtual cards list response""" pagination = response.get("pagination", {}) cards = response.get("virtualCards", []) if not cards: return "No virtual cards found." result = f"Pagination:{json.dumps(pagination)}\n\nVirtual Cards:\n\n" for card in cards: result += ( f"- ID: {card['id']}\n" f" Name: {card['displayName']}\n" f" Status: {card['status']}\n" f" Balance: ${card['balanceCents'] / 100:.2f}\n" f" Expires: {card['expires']}\n\n" ) return result def format_canceled_virtual_card(response: Dict) -> str: """Format the canceled virtual card response""" card = response.get("virtualCard", {}) if not card: return "Virtual card not found." return ( f"Virtual Card Cancelled Successfully!\n\n" f"ID: {card['id']}\n" f"Name: {card['displayName']}\n" f"Status: {card['status']}\n" f"Balance: ${card['balanceCents'] / 100:.2f}\n" ) def format_closed_virtual_card(response: Dict) -> str: """Format the closed virtual card response""" card = response.get("virtualCard", {}) if not card: return "Virtual card not found." 
return ( f"Virtual Card Closed Successfully!\n\n" f"ID: {card['id']}\n" f"Name: {card['displayName']}\n" f"Status: {card['status']}\n" f"Final Balance: ${card['balanceCents'] / 100:.2f}\n" ) def format_virtual_card_details(response: Dict) -> str: """Format the detailed virtual card response""" card = response.get("virtualCard", {}) if not card: return "Virtual card not found." return ( f"Virtual Card Details:\n\n" f"ID: {card['id']}\n" f"Name: {card['displayName']}\n" f"Status: {card['status']}\n" f"Balance: ${card['balanceCents'] / 100:.2f}\n" f"Spent: ${card['spentCents'] / 100:.2f}\n" f"Limit: ${card['limitCents'] / 100:.2f}\n" f"Last 4: {card['last4']}\n" f"Expires: {card['expires']}\n" f"Valid From: {card['validFrom']}\n" f"Valid To: {card['validTo']}\n" f"Recipient: {card.get('recipientId', 'N/A')}\n" f"Notes: {card.get('notes', 'N/A')}\n" ) def format_credit_cards_list(response: Dict) -> str: """Format the credit cards list response""" cards = response.get("creditCards", []) if not cards: return "No credit cards found." result = "Available Credit Cards:\n\n" for card in cards: result += ( f"- ID: {card['id']}\n" f" Name: {card['displayName']}\n" f" Status: {card['status']}\n" f" Last 4: {card['last4']}\n" f" Issuer: {card['issuerName']}\n\n" ) return result def format_credit_card_detail(response: Dict) -> str: """Format the credit card detail response""" card = response.get("creditCard", {}) if not card: return "No credit card found." 
card_features = card['features'] or {} return ( f"Credit Card Details:\n\n" f"- ID: {card['id']}\n" f" Name: {card['displayName']}\n" f" Card User: {card['user']['firstName']} {card['user']['lastName']}\n" f" Is Budget: {card['parentCreditCardId'] is not None}\n" f" Status: {card['status']}\n" f" Last 4: {card['last4']}\n" f" Issuer: {card['issuerName']}\n" f" Guest Cards Enabled: {card_features['direct']}\n" f" Receipt Management Enabled: {card_features['receiptManagementEnabled']}\n" f" Receipt Capture Enabled: {card_features['receiptCaptureEnabled']}\n" f" Bill Pay Enabled: {card_features['billPay']}\n\n" ) def format_transactions_list(response: Dict) -> str: """Format the transactions list response""" # Handle case where response is error message if isinstance(response, str): return response # Get report data report = response.get("report", {}) transactions = report.get("transactions", []) if not transactions: return "No transactions found." # Add pagination info current_page = report.get("page", 1) total_pages = report.get("numPages", 1) per_page = report.get("per_page", 25) total_count = report.get("count", 0) result = f"Recent Transactions (Page {current_page} of {total_pages}, {total_count} total):\n\n" for txn in transactions: # Always include these required fields txn_id = txn.get('id') amount_cents = txn.get('clearingBillingAmountCents', txn.get('authBillingAmountCents', 0)) status = txn.get('status') # Start the transaction entry result += f"- ID: {txn_id}\n" result += f" Amount: ${amount_cents / 100:.2f}\n" result += f" Status: {status}\n" # Date can be under authedAt or clearedAt; skip if neither is provided txn_date = txn.get('authedAt', txn.get('clearedAt')) result += add_line("Date", txn_date) # Optional fields – add only if they have a valid value result += add_line("VCN ID", txn.get('virtualCardId')) result += add_line("VCN Name", txn.get('virtualCardDisplayName')) result += add_line("Cardholder Name", txn.get('cardholderName')) result += 
add_line("Recipient Name", txn.get('recipientName')) result += add_line("Merchant", txn.get('merchantName')) result += add_line("MCC", txn.get('mccDescription')) result += add_line("Notes", txn.get('notes')) result += add_line("Review Status", txn.get('reviewStatus')) result += add_line("Receipt Required", txn.get('receiptRequired')) result += add_line("Receipt Attachments Count", txn.get('attachmentsCount')) # For fields like connectedPlatforms that require some processing, # compute the value first synced_to_erp = True if txn.get('connectedPlatforms') and len(txn.get('connectedPlatforms')) > 0 else False result += add_line("Synced to ERP", synced_to_erp) # Optionally add a blank line or separator between transactions result += "\n" if current_page < total_pages: result += f"\nThere are more transactions available. Use page parameter to view next page." return result def format_transaction_details(response: Dict) -> str: """Format the transaction detail response""" txn = response if not txn: return "Transaction not found." 
amount = txn.get('clearingBillingAmountCents', txn.get('authBillingAmountCents', 0)) return ( f"Transaction Details:\n\n" f"ID: {txn['id']}\n" f"Merchant: {txn.get('merchantName', 'N/A')}\n" f"Amount: ${amount / 100:.2f}\n" f"Status: {txn['status']}\n" f"Type: {txn['type']}\n" f"Card: {txn.get('virtualCardId', 'N/A')}\n" f"Authorization Date: {txn.get('authedAt', 'N/A')}\n" f"Clearing Date: {txn.get('clearedAt', 'N/A')}\n" f"MCC: {txn.get('mcc', 'N/A')}\n" f"Notes: {txn.get('notes', 'N/A')}\n" ) ``` -------------------------------------------------------------------------------- /extend_ai_toolkit/modelcontextprotocol/client/mcp_client.py: -------------------------------------------------------------------------------- ```python import argparse import asyncio import json import logging import sys from contextlib import asynccontextmanager from typing import Optional, List, Dict, Any from dotenv import load_dotenv from mcp import ClientSession from mcp.client.sse import sse_client from mypy.util import json_dumps from extend_ai_toolkit.modelcontextprotocol.client import ( AnthropicChatClient, OpenAIChatClient, ChatClient ) logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) logger = logging.getLogger("mcp_client") load_dotenv() class MCPClient: """ Client for interacting with Model Capability Protocol (MCP) servers using Server-Sent Events (SSE) transport and the OpenAI API. """ def __init__(self, llm_client: ChatClient, model_name="gpt-4o", max_tokens=1000): self.session: Optional[ClientSession] = None self._session_context = None self._streams_context = None self.model_name = model_name self.max_tokens = max_tokens self.llm_client = llm_client @asynccontextmanager async def connect(self, server_url: str): """ Connect to MCP server with SSE transport as an async context manager. 
Args: server_url: URL of the SSE MCP server """ try: # Connect to SSE server self._streams_context = sse_client(url=server_url) streams = await self._streams_context.__aenter__() # Create client session self._session_context = ClientSession(*streams) self.session = await self._session_context.__aenter__() # Initialize session await self.session.initialize() # List available tools (for logging purposes) response = await self.session.list_tools() tool_names = [tool.name for tool in response.tools] logger.info(f"Connected to server with tools: {tool_names}") yield self except Exception as e: logger.error(f"Error connecting to SSE server: {str(e)}") raise finally: await self.cleanup() async def cleanup(self): """Properly clean up the session and streams""" if self._session_context: await self._session_context.__aexit__(None, None, None) self._session_context = None if self._streams_context: await self._streams_context.__aexit__(None, None, None) self._streams_context = None self.session = None async def list_available_tools(self) -> List[Dict[str, Any]]: """ Get a list of available tools from the MCP server. Returns: List of tool dictionaries with name, description, and input schema """ if not self.session: raise ConnectionError("Not connected to MCP server") response = await self.session.list_tools() return [{ "name": tool.name, "description": tool.description, "input_schema": tool.inputSchema } for tool in response.tools] async def process_query(self, query: str) -> str: """ Process a query using OpenAI's ChatCompletion endpoint. Args: query: User query string Returns: Response text from the OpenAI API. 
""" if not self.session: raise ConnectionError("Not connected to MCP server") messages = [{"role": "user", "content": query}] # Get available MCP tools and convert them into function definitions available_tools = await self.list_available_tools() functions = [] for tool in available_tools: # Convert your tool's input_schema to a valid JSON schema if needed functions.append({ "name": tool["name"], "description": tool["description"], "parameters": tool["input_schema"] }) final_text = [] try: # Call the LLM API content, function_call = await self.llm_client.generate_completion( messages=messages, functions=functions, max_tokens=self.max_tokens, ) if function_call: tool_name = function_call["name"] tool_arguments_str = function_call["arguments"] try: # Convert the JSON string into a dictionary tool_arguments = json.loads(tool_arguments_str) if tool_arguments_str else None except json.JSONDecodeError as e: logger.error(f"Error parsing tool arguments: {str(e)}") tool_arguments = None logger.info(f"Routing function call to tool: {tool_name} with args: {json_dumps(tool_arguments)}") # Call the corresponding tool on the MCP server tool_result = await self.session.call_tool(tool_name, tool_arguments) # Append the function call and tool result to the conversation history messages.append({ "role": "assistant", "content": None, "function_call": { "name": tool_name, "arguments": tool_arguments_str } }) messages.append({ "role": "function", "name": tool_name, "content": tool_result.content }) # Make a follow-up API call including the tool result assistant_message = await self.llm_client.generate_with_tool_result( messages=messages, max_tokens=self.max_tokens ) final_text.append(assistant_message) return "\n".join(final_text) else: # No function call; return the assistant's message directly final_text.append(content) return "\n".join(final_text) except Exception as e: error_msg = f"Error processing query: {str(e)}" logger.error(error_msg) return error_msg async def 
chat_loop(self): """Run an interactive chat loop""" if not self.session: raise ConnectionError("Not connected to MCP server") print("\nExtend MCP Client Started!") print("Enter your queries or type 'quit' to exit.") while True: try: await asyncio.sleep(0.1) sys.stdout.flush() query = input("\nQuery: ").strip() if query.lower() in ('quit', 'exit', 'q'): break print("Processing query...") response = await self.process_query(query) print("\nResponse:") print(response) except KeyboardInterrupt: print("\nExiting chat loop...") break except Exception as e: logger.error(f"Error in chat loop: {str(e)}") print(f"\nError: {str(e)}") async def main(): """Main entry point for the MCP client""" parser = argparse.ArgumentParser(description="MCP Client for interacting with SSE-based servers.") parser.add_argument("--llm-provider", type=str, choices=["openai", "anthropic"], default="openai", help="LLM Provider (e.g., openai)") parser.add_argument("--llm-model", type=str, help="LLM Model (e.g., gpt-4o, claude-3-5-sonnet-20240229)") parser.add_argument("--mcp-server-host", type=str, required=True, help="Server hostname (e.g., localhost)") parser.add_argument("--mcp-server-port", type=int, required=True, help="Server port (e.g., 8000)") parser.add_argument("--scheme", type=str, choices=["http", "https"], default="http", help="URL scheme (default: http)") args = parser.parse_args() server_url = f"{args.scheme}://{args.mcp_server_host}:{args.mcp_server_port}/sse" print(f"Connecting to: {server_url}") if args.llm_provider == "openai": model = args.llm_model or "gpt-4o" llm_client = OpenAIChatClient(model_name=model) else: model = args.llm_model or "claude-3-7-sonnet-20250219" llm_client = AnthropicChatClient(model_name=model) try: async with MCPClient( llm_client=llm_client ).connect(server_url=server_url) as client: await client.chat_loop() except KeyboardInterrupt: print("\nProgram terminated by user") except Exception as e: logger.error(f"Unhandled exception: {str(e)}") 
print(f"\nError: {str(e)}") sys.exit(1) if __name__ == "__main__": asyncio.run(main()) ``` -------------------------------------------------------------------------------- /extend_ai_toolkit/shared/tools.py: -------------------------------------------------------------------------------- ```python from typing import List, TypedDict, Type from pydantic import BaseModel from .configuration import Scope, Product from .enums import ExtendAPITools from .prompts import ( get_virtual_cards_prompt, get_virtual_card_detail_prompt, cancel_virtual_card_prompt, close_virtual_card_prompt, get_transactions_prompt, get_transaction_detail_prompt, get_credit_cards_prompt, get_expense_categories_prompt, get_expense_category_prompt, get_expense_category_labels_prompt, create_expense_category_prompt, create_expense_category_label_prompt, update_expense_category_prompt, get_credit_card_detail_prompt, update_transaction_expense_data_prompt, create_receipt_attachment_prompt, get_automatch_status_prompt, automatch_receipts_prompt, send_receipt_reminder_prompt, ) from .schemas import ( GetVirtualCards, GetVirtualCardDetail, CancelVirtualCard, CloseVirtualCard, GetCreditCards, GetTransactions, GetTransactionDetail, GetExpenseCategories, GetExpenseCategory, GetExpenseCategoryLabels, CreateExpenseCategory, CreateExpenseCategoryLabel, UpdateExpenseCategory, GetCreditCardDetail, UpdateTransactionExpenseData, GetAutomatchStatusSchema, AutomatchReceiptsSchema, CreateReceiptAttachmentSchema, SendReceiptReminderSchema, ) class ActionDict(TypedDict): read: bool create: bool update: bool delete: bool class Tool(BaseModel): method: ExtendAPITools description: str args_schema: Type[BaseModel] required_scope: List[Scope] @property def name(self) -> str: return self.method.value tools: List[Tool] = [ Tool( method=ExtendAPITools.GET_VIRTUAL_CARDS, description=get_virtual_cards_prompt, args_schema=GetVirtualCards, required_scope=[ Scope( type=Product.VIRTUAL_CARDS, actions={"read": True}) ], ), Tool( 
method=ExtendAPITools.GET_VIRTUAL_CARD_DETAIL,
        description=get_virtual_card_detail_prompt,
        args_schema=GetVirtualCardDetail,
        required_scope=[
            Scope(
                type=Product.VIRTUAL_CARDS,
                actions={"read": True})
        ],
    ),
    # Cancelling or closing a virtual card requires update permission as well as read.
    Tool(
        method=ExtendAPITools.CANCEL_VIRTUAL_CARD,
        description=cancel_virtual_card_prompt,
        args_schema=CancelVirtualCard,
        required_scope=[
            Scope(
                type=Product.VIRTUAL_CARDS,
                actions={
                    "read": True,
                    "update": True,
                })
        ],
    ),
    Tool(
        method=ExtendAPITools.CLOSE_VIRTUAL_CARD,
        description=close_virtual_card_prompt,
        args_schema=CloseVirtualCard,
        required_scope=[
            Scope(
                type=Product.VIRTUAL_CARDS,
                actions={
                    "read": True,
                    "update": True,
                })
        ],
    ),
    # Credit card tools are read-only.
    Tool(
        method=ExtendAPITools.GET_CREDIT_CARDS,
        description=get_credit_cards_prompt,
        args_schema=GetCreditCards,
        required_scope=[
            Scope(
                type=Product.CREDIT_CARDS,
                actions={
                    "read": True,
                })
        ],
    ),
    Tool(
        method=ExtendAPITools.GET_CREDIT_CARD_DETAIL,
        description=get_credit_card_detail_prompt,
        args_schema=GetCreditCardDetail,
        required_scope=[
            Scope(
                type=Product.CREDIT_CARDS,
                actions={"read": True}
            )
        ],
    ),
    # Transaction tools; only the expense-data update needs update permission.
    Tool(
        method=ExtendAPITools.GET_TRANSACTIONS,
        description=get_transactions_prompt,
        args_schema=GetTransactions,
        required_scope=[
            Scope(
                type=Product.TRANSACTIONS,
                actions={
                    "read": True,
                })
        ],
    ),
    Tool(
        method=ExtendAPITools.GET_TRANSACTION_DETAIL,
        description=get_transaction_detail_prompt,
        args_schema=GetTransactionDetail,
        required_scope=[
            Scope(
                type=Product.TRANSACTIONS,
                actions={
                    "read": True,
                })
        ],
    ),
    Tool(
        method=ExtendAPITools.UPDATE_TRANSACTION_EXPENSE_DATA,
        description=update_transaction_expense_data_prompt,
        args_schema=UpdateTransactionExpenseData,
        required_scope=[
            Scope(
                type=Product.TRANSACTIONS,
                actions={
                    "read": True,
                    "update": True,
                }
            )
        ],
    ),
    # Expense category tools: reads need "read"; mutations add "create" or "update".
    Tool(
        method=ExtendAPITools.GET_EXPENSE_CATEGORIES,
        description=get_expense_categories_prompt,
        args_schema=GetExpenseCategories,
        required_scope=[
            Scope(
                type=Product.EXPENSE_CATEGORIES,
                actions={"read": True}
            )
        ],
    ),
    Tool(
        method=ExtendAPITools.GET_EXPENSE_CATEGORY,
        description=get_expense_category_prompt,
        args_schema=GetExpenseCategory,
        required_scope=[
            Scope(
                type=Product.EXPENSE_CATEGORIES,
                actions={"read": True}
            )
        ],
    ),
    Tool(
        method=ExtendAPITools.GET_EXPENSE_CATEGORY_LABELS,
        description=get_expense_category_labels_prompt,
        args_schema=GetExpenseCategoryLabels,
        required_scope=[
            Scope(
                type=Product.EXPENSE_CATEGORIES,
                actions={"read": True}
            )
        ],
    ),
    Tool(
        method=ExtendAPITools.CREATE_EXPENSE_CATEGORY,
        description=create_expense_category_prompt,
        args_schema=CreateExpenseCategory,
        required_scope=[
            Scope(
                type=Product.EXPENSE_CATEGORIES,
                actions={"read": True, "create": True}
            )
        ],
    ),
    Tool(
        method=ExtendAPITools.CREATE_EXPENSE_CATEGORY_LABEL,
        description=create_expense_category_label_prompt,
        args_schema=CreateExpenseCategoryLabel,
        required_scope=[
            Scope(
                type=Product.EXPENSE_CATEGORIES,
                actions={"read": True, "create": True}
            )
        ],
    ),
    Tool(
        method=ExtendAPITools.UPDATE_EXPENSE_CATEGORY,
        description=update_expense_category_prompt,
        args_schema=UpdateExpenseCategory,
        required_scope=[
            Scope(
                type=Product.EXPENSE_CATEGORIES,
                actions={"read": True, "update": True}
            )
        ],
    ),
    # Receipt tools: each also requires a TRANSACTIONS scope alongside
    # RECEIPT_ATTACHMENTS, except get_automatch_status.
    Tool(
        method=ExtendAPITools.CREATE_RECEIPT_ATTACHMENT,
        description=create_receipt_attachment_prompt,
        args_schema=CreateReceiptAttachmentSchema,
        required_scope=[
            Scope(
                type=Product.RECEIPT_ATTACHMENTS,
                actions={"read": True, "create": True}
            ),
            Scope(
                type=Product.TRANSACTIONS,
                actions={"read": True, "update": True}
            )
        ],
    ),
    Tool(
        method=ExtendAPITools.AUTOMATCH_RECEIPTS,
        description=automatch_receipts_prompt,
        args_schema=AutomatchReceiptsSchema,
        required_scope=[
            Scope(
                type=Product.RECEIPT_ATTACHMENTS,
                actions={"read": True}
            ),
            Scope(
                type=Product.TRANSACTIONS,
                actions={"read": True, "update": True}
            )
        ],
    ),
    Tool(
        method=ExtendAPITools.GET_AUTOMATCH_STATUS,
        description=get_automatch_status_prompt,
        args_schema=GetAutomatchStatusSchema,
        required_scope=[
            Scope(
                type=Product.RECEIPT_ATTACHMENTS,
                actions={"read": True}
            )
        ],
    ),
    Tool(
        method=ExtendAPITools.SEND_RECEIPT_REMINDER,
        description=send_receipt_reminder_prompt,
        args_schema=SendReceiptReminderSchema,
        required_scope=[
            Scope(
                type=Product.RECEIPT_ATTACHMENTS,
                actions={"read": True}
            ),
            Scope(
                type=Product.TRANSACTIONS,
                actions={"read": True}
            )
        ],
    ),
    # NOTE(review): the disabled entries below pass `name=`, which is not a
    # declared field of `Tool` (its name is a read-only property derived from
    # `method`). Drop that argument before re-enabling them, and add their
    # prompt/schema names to the imports at the top of this module.
    # Tool(
    #     method=ExtendAPITools.PROPOSE_EXPENSE_CATEGORY_LABEL,
    #     name="propose_transaction_expense_data",
    #     description=propose_transaction_expense_data_prompt,
    #     args_schema=ProposeTransactionExpenseData,
    #     required_scope=[
    #         Scope(
    #             type=Product.TRANSACTIONS,
    #             actions={"read": True}
    #         ),
    #         Scope(
    #             type=Product.EXPENSE_CATEGORIES,
    #             actions={"read": True}
    #         )
    #     ],
    # ),
    # Tool(
    #     method=ExtendAPITools.CONFIRM_EXPENSE_CATEGORY_LABEL,
    #     name="confirm_transaction_expense_data",
    #     description=confirm_transaction_expense_data_prompt,
    #     args_schema=ConfirmTransactionExpenseData,
    #     required_scope=[
    #         Scope(
    #             type=Product.TRANSACTIONS,
    #             actions={"read": True, "update": True}
    #         ),
    #         Scope(
    #             type=Product.EXPENSE_CATEGORIES,
    #             actions={"read": True}
    #         )
    #     ],
    # )
]
```
--------------------------------------------------------------------------------
/extend_ai_toolkit/shared/schemas.py:
--------------------------------------------------------------------------------
```python
from typing import Dict, Optional, List

from pydantic import BaseModel, Field


# Argument schemas for the agent tools. A `...` default marks a field as
# required; fields with any other default are optional.
class GetVirtualCards(BaseModel):
    """Schema for the `get_virtual_cards` operation."""
    page: int = Field(
        0,
        description="Pagination page number, default is 0."
    )
    per_page: int = Field(
        10,
        description="Number of items per page, default is 10."
    )
    status: Optional[str] = Field(
        None,
        description="Filter virtual cards by status. Options: ACTIVE, CANCELLED, PENDING, EXPIRED, CLOSED, CONSUMED."
    )
    recipient: Optional[str] = Field(
        None,
        description="Filter virtual cards by recipient identifier."
    )
    search_term: Optional[str] = Field(
        None,
        description="Search term to filter virtual cards."
) sort_field: Optional[str] = Field( None, description="Field to sort by: 'createdAt', 'updatedAt', 'balanceCents', 'displayName', 'type', or 'status'." ) sort_direction: Optional[str] = Field( None, description="Sort direction, ASC or DESC." ) class GetVirtualCardDetail(BaseModel): """Schema for the `get_virtual_card_detail` operation.""" virtual_card_id: str = Field( ..., description="The ID of the virtual card." ) class CloseVirtualCard(BaseModel): """Schema for the `close_virtual_card` operation.""" virtual_card_id: str = Field( ..., description="The ID of the virtual card to close." ) class CancelVirtualCard(BaseModel): """Schema for the `cancel_virtual_card` operation.""" virtual_card_id: str = Field( ..., description="The ID of the virtual card to cancel." ) class GetTransactions(BaseModel): """Schema for the `get_transactions` operation.""" page: int = Field( 0, description="Pagination page number, default is 0." ) per_page: int = Field( 50, description="Number of transactions per page, default is 50." ) from_date: Optional[str] = Field( None, description="Start date to filter transactions (YYYY-MM-DD)." ) to_date: Optional[str] = Field( None, description="End date to filter transactions (YYYY-MM-DD)." ) status: Optional[str] = Field( None, description="Filter transactions by status (e.g., PENDING, CLEARED, DECLINED, etc.)." ) virtual_card_id: Optional[str] = Field( None, description="Filter transactions by a specific virtual card ID." ) min_amount_cents: Optional[int] = Field( None, description="Minimum transaction amount in cents." ) max_amount_cents: Optional[int] = Field( None, description="Maximum transaction amount in cents." ) receipt_missing: Optional[bool] = Field( None, description="Filter transactions by whether they are missing a receipt." ) search_term: Optional[str] = Field( None, description="Filter transactions by search term." ) sort_field: Optional[str] = Field( None, description="Field to sort by, with optional direction. 
Use 'recipientName', 'merchantName', 'amount', 'date' for ASC. Use '-recipientName', '-merchantName', '-amount', '-date' for DESC." ) class GetTransactionDetail(BaseModel): """Schema for the `get_transaction_detail` operation.""" transaction_id: str = Field( ..., description="The ID of the transaction to retrieve details for." ) class ProposeTransactionExpenseData(BaseModel): """Schema for the `propose_transaction_expense_data` operation.""" transaction_id: str = Field( ..., description="The unique identifier of the transaction." ) data: Dict = Field( ..., description=( "A dictionary representing the expense details to propose. " "Expected format: {'expenseDetails': [{'categoryId': str, 'labelId': str}]}." ) ) class ProposeTransactionExpenseDataResponse(BaseModel): """Response schema for the `propose_transaction_expense_data` operation.""" status: str = Field( default="pending_confirmation", description="Status of the expense data proposal." ) transaction_id: str = Field( ..., description="The unique identifier of the transaction." ) confirmation_token: str = Field( ..., description="The unique token required to confirm this expense data update." ) expires_at: str = Field( ..., description="ISO-8601 timestamp when this proposal expires." ) proposed_categories: List[Dict] = Field( ..., description="List of proposed expense categories and labels." ) class ConfirmTransactionExpenseData(BaseModel): """Schema for the `confirm_transaction_expense_data` operation.""" confirmation_token: str = Field( ..., description="The unique token from the proposal step that was shared with the user." ) class UpdateTransactionExpenseData(BaseModel): """Schema for the `update_transaction_expense_data` operation.""" transaction_id: str = Field( ..., description="The unique identifier of the transaction." ) user_confirmed_data_values: bool = Field( ..., description="Indicates whether or not the user has confirmed the specific values used in the data argument." 
) data: Dict = Field( ..., description=( "A dictionary representing the expense details to update. " "Expected format: {'expenseDetails': [{'categoryId': str, 'labelId': str}]}." ) ) class GetCreditCards(BaseModel): """Schema for the `get_credit_cards` operation.""" page: int = Field( 0, description="Pagination page number, default is 0." ) per_page: int = Field( 10, description="Number of credit cards per page, default is 10." ) status: Optional[str] = Field( None, description="Filter credit cards by status." ) search_term: Optional[str] = Field( None, description="Search term to filter credit cards." ) sort_direction: Optional[str] = Field( None, description="Sort direction, ASC or DESC." ) class GetCreditCardDetail(BaseModel): """Schema for the `get_credit_card_detail` operation.""" credit_card_id: str = Field( ..., description="The ID of the credit card to retrieve details for." ) class GetExpenseCategories(BaseModel): """Schema for the `get_expense_categories` operation.""" active: Optional[bool] = Field( None, description="Filter categories by active status." ) required: Optional[bool] = Field( None, description="Filter categories by required status." ) search: Optional[str] = Field( None, description="Search term to filter categories." ) sort_field: Optional[str] = Field( None, description="Field to sort the categories by." ) sort_direction: Optional[str] = Field( None, description="Direction to sort the categories (ASC or DESC)." ) class GetExpenseCategory(BaseModel): """Schema for the `get_expense_category` operation.""" category_id: str = Field( ..., description="The ID of the expense category." ) class GetExpenseCategoryLabels(BaseModel): """Schema for the `get_expense_category_labels` operation.""" category_id: str = Field( ..., description="The ID of the expense category." ) page: Optional[int] = Field( 0, description="Pagination page number, default is 0." ) per_page: Optional[int] = Field( 10, description="Number of labels per page, default is 10." 
) active: Optional[bool] = Field( None, description="Filter labels by active status." ) search: Optional[str] = Field( None, description="Search term to filter labels." ) sort_field: Optional[str] = Field( None, description="Field to sort labels by." ) sort_direction: Optional[str] = Field( None, description="Direction to sort the labels (ASC or DESC)." ) class CreateExpenseCategory(BaseModel): """Schema for the `create_expense_category` operation.""" name: str = Field( ..., description="The name of the expense category." ) code: str = Field( ..., description="A unique code for the expense category." ) required: bool = Field( ..., description="Indicates whether the expense category is required." ) active: Optional[bool] = Field( None, description="The active status of the category." ) free_text_allowed: Optional[bool] = Field( None, description="Indicates if free text is allowed." ) class CreateExpenseCategoryLabel(BaseModel): """Schema for the `create_expense_category_label` operation.""" category_id: str = Field( ..., description="The ID of the expense category." ) name: str = Field( ..., description="The name of the expense category label." ) code: str = Field( ..., description="A unique code for the expense category label." ) active: bool = Field( True, description="The active status of the label (defaults to True)." ) class UpdateExpenseCategory(BaseModel): """Schema for the `update_expense_category` operation.""" category_id: str = Field( ..., description="The ID of the expense category to update." ) name: Optional[str] = Field( None, description="The new name for the expense category." ) active: Optional[bool] = Field( None, description="The updated active status." ) required: Optional[bool] = Field( None, description="The updated required status." ) free_text_allowed: Optional[bool] = Field( None, description="Indicates if free text is allowed." 
) class UpdateExpenseCategoryLabel(BaseModel): """Schema for the `update_expense_category_label` operation.""" category_id: str = Field( ..., description="The ID of the expense category." ) label_id: str = Field( ..., description="The ID of the expense category label to update." ) name: Optional[str] = Field( None, description="The new name for the label." ) active: Optional[bool] = Field( None, description="The updated active status of the label." ) class CreateReceiptAttachmentSchema(BaseModel): """Schema for the `create_receipt_attachment` operation.""" file_path: str = Field( ..., description="File path for the receipt attachment to be uploaded via multipart form data." ) transaction_id: Optional[str] = Field( ..., description="The optional unique identifier of the transaction to attach the receipt to." ) class AutomatchReceiptsSchema(BaseModel): """Schema for the `automatch_receipts` operation.""" receipt_attachment_ids: List[str] = Field( ..., description="A list of receipt attachment IDs to be automatched." ) class GetAutomatchStatusSchema(BaseModel): """Schema for the `get_automatch_status` operation.""" job_id: str = Field( ..., description="The unique identifier of the automatch job whose status is to be retrieved." ) class SendReceiptReminderSchema(BaseModel): """Schema for the `send_receipt_reminder` operation.""" transaction_id: str = Field( ..., description="The unique identifier of the transaction to send a receipt reminder for." ) ``` -------------------------------------------------------------------------------- /extend_ai_toolkit/shared/prompts.py: -------------------------------------------------------------------------------- ```python get_virtual_cards_prompt = """ This tool will retrieve all of the user's virtual cards from Extend. It takes the following arguments: - page (int): The page number for the paginated list of virtual cards. - per_page (int): The number of virtual cards per page. 
- status (Optional[str]): Filter virtual cards by status (e.g., ACTIVE, CANCELLED, PENDING, EXPIRED, CLOSED, CONSUMED). - recipient (Optional[str]): Filter by the recipient identifier. - search_term (Optional[str]): A search term to filter the virtual cards. - sort_field (Optional[str]): Field to sort by (e.g., 'createdAt', 'updatedAt', 'balanceCents', 'displayName', 'type', or 'status'). - sort_direction (Optional[str]): Sort direction, either ASC or DESC. USE "DESC" FOR MOST RECENT OR HIGHEST VALUES FIRST. IMPORTANT USAGE GUIDELINES: 1. When retrieving recently created cards, ALWAYS set sort_field="createdAt" and sort_direction="DESC". 2. Use status filters whenever possible to narrow results (e.g., status="ACTIVE" for only active cards). 3. For specific cards, use search_term to reduce the result set size. The response includes the fetched virtual cards as well pagination metadata. """ get_virtual_card_detail_prompt = """ This tool retrieves detailed information for a specific virtual card from Extend. It takes the following argument: - virtual_card_id (str): The ID of the virtual card. The response contains all details of the virtual card. """ cancel_virtual_card_prompt = """ This tool cancels a virtual card in Extend. It takes the following argument: - virtual_card_id (str): The ID of the virtual card to cancel. """ close_virtual_card_prompt = """ This tool closes a virtual card in Extend. It takes the following argument: - virtual_card_id (str): The ID of the virtual card to close. """ get_credit_cards_prompt = """ This tool retrieves a list of credit cards from Extend. It takes the following arguments: - page (int): The page number for the paginated list. - per_page (int): The number of credit cards per page. - status (Optional[str]): Filter credit cards by status. - search_term (Optional[str]): A search term to filter credit cards. - sort_direction (Optional[str]): Sort direction (ASC or DESC). 
The response includes fetched credit cards and pagination metadata. """ get_credit_card_detail_prompt = """ This tool retrieves detailed information for a specific credit card in Extend. It takes the following argument: - credit_card_id (str): The ID of the credit card. The response includes the credit card's detailed information. """ get_transactions_prompt = """ This tool retrieves a list of transactions from Extend. It takes the following arguments: - page (int): The page number for the paginated list. - per_page (int): The number of transactions per page. - from_date (Optional[str]): Filter transactions starting from this date (YYYY-MM-DD). - to_date (Optional[str]): Filter transactions up to this date (YYYY-MM-DD). - status (Optional[str]): Filter transactions by status (e.g., PENDING, CLEARED, DECLINED, etc.). - virtual_card_id (Optional[str]): Filter by a specific virtual card ID. - min_amount_cents (Optional[int]): Minimum transaction amount in cents. - max_amount_cents (Optional[int]): Maximum transaction amount in cents. - receipt_missing (Optional[bool]): Filter transactions by whether they are missing a receipt - search_term (Optional[str]): A search term to filter transactions. - sort_field (Optional[str]): Field to sort by, with optional direction Use "recipientName", "merchantName", "amount", "date" for ASC Use "-recipientName", "-merchantName", "-amount", "-date" for DESC IMPORTANT USAGE GUIDELINES: 1. When retrieving most recent transactions, ALWAYS use sort_field="-date" (negative prefix indicates descending order). 2. Use filters (from_date, to_date, status) whenever possible to reduce result set size. 3. For large result sets, use pagination appropriately with reasonable per_page values. 4. 
Note that sort direction is specified as part of the sort_field parameter: - For DESCENDING order (newest to oldest, highest to lowest), prefix the sort_field value with "-" (e.g., "-date") - For ASCENDING order (oldest to newest, lowest to highest), use the sort_field value without a prefix (e.g., "date") There is no separate sort_direction parameter. The response is a JSON object with a "reports" key containing: - "transactions": An array of transaction objects - "page": The current page number - "pageItemCount": Number of items per page - "totalItems": Total number of transactions matching the query - "numberOfPages": Total number of pages available """ get_transaction_detail_prompt = """ This tool retrieves detailed information for a specific transaction in Extend. It takes the following argument: - transaction_id (str): The ID of the transaction. The response includes the transaction's detailed information. """ propose_transaction_expense_data_prompt = """ IMPORTANT: This tool does NOT immediately update expense data. It only proposes changes that require user confirmation. This tool will propose expense data changes for a specific transaction in Extend. It takes the following arguments: - transaction_id (str): The unique identifier of the transaction. - data (Dict): A dictionary representing the expense data to propose. Expected format: { "expenseDetails": [{"categoryId": str, "labelId": str}] } The response is a JSON object with proposal details including a confirmation token. After calling this tool, you MUST present the confirmation details to the user and explicitly ask them to confirm before proceeding. """ confirm_transaction_expense_data_prompt = """ IMPORTANT: This tool finalizes expense data changes that were previously proposed. It takes the following argument: - confirmation_token (str): The unique token from the proposal step that was provided to the user. 
DO NOT attempt to use this tool unless the user has explicitly provided the confirmation token. The response is a JSON object with the updated transaction details. """ update_transaction_expense_data_prompt = """ IMPORTANT: NEVER use this tool without confirming with the user which expense category and label to use. If the user has not specified a category and label, you must ask them for their selection before proceeding. Only proceed with the update after receiving explicit confirmation from the user. IMPORTANT: TRANSACTIONS of any status can be updated. Step 1: If the user has not specified an expense category and label, present user with all of the the available categories and ask them to select one Step 2: Once the user has confirmed the expense category, then present them with the list of labels for that expense category Step 3: Only proceed with the update after receiving the users explicit confirmation This tool updates the expense data for a specific transaction in Extend. It takes the following arguments: - transaction_id (str): The unique identifier of the transaction. - user_confirmed_data_values (bool): Must be True if the user has confirmed the expense category and label values. - data (Dict): A dictionary representing the expense data to update. Expected format: { "expenseDetails": [{"categoryId": str, "labelId": str}] } The response includes the updated transaction's details. """ get_expense_categories_prompt = """ This tool retrieves a list of expense categories from Extend. It takes the following optional arguments: - active (Optional[bool]): Filter categories by their active status. - required (Optional[bool]): Filter categories by whether they are required. - search (Optional[str]): A search term to filter categories. - sort_field (Optional[str]): Field to sort the categories by (e.g., "name", "code", "createdAt"). - sort_direction (Optional[str]): Direction to sort ("ASC" for ascending or "DESC" for descending). IMPORTANT USAGE GUIDELINES: 1. 
When retrieving categories in alphabetical order, use sort_field="name" and sort_direction="ASC". 2. When retrieving most recently created categories, use sort_field="createdAt" and sort_direction="DESC". 3. Use the active=true filter to retrieve only currently active categories. 4. When looking for a specific category, use the search parameter to narrow results. The response includes the fetched list of expense categories. """ get_expense_category_prompt = """ This tool retrieves detailed information for a specific expense category from Extend. It takes the following argument: - category_id (str): The ID of the expense category. The response includes the expense category details. """ get_expense_category_labels_prompt = """ This tool retrieves a paginated list of labels for a specific expense category in Extend. It takes the following arguments: - category_id (str): The ID of the expense category. - page (Optional[int]): The page number for pagination (default is 0). - per_page (Optional[int]): The number of labels per page (default is 10). - active (Optional[bool]): Filter labels by their active status. - search (Optional[str]): A search term to filter labels. - sort_field (Optional[str]): Field to sort the labels by (e.g., "name", "code", "createdAt"). - sort_direction (Optional[str]): Direction to sort ("ASC" for ascending or "DESC" for descending). IMPORTANT USAGE GUIDELINES: 1. The category_id parameter is required and must be valid. 2. When retrieving labels in alphabetical order, use sort_field="name" and sort_direction="ASC". 3. When retrieving most recently created labels, use sort_field="createdAt" and sort_direction="DESC". 4. Use the active=true filter to retrieve only currently active labels. 5. For retrieving all labels, increase per_page parameter to an appropriate value. 6. When looking for a specific label, use the search parameter to narrow results. The response includes the fetched expense category labels and pagination metadata. 
""" create_expense_category_prompt = """ This tool creates a new expense category in Extend. It takes the following arguments: - name (str): The name of the expense category. - code (str): A unique code for the expense category. - required (bool): Indicates whether the expense category is required. - active (Optional[bool]): The active status of the category. - free_text_allowed (Optional[bool]): Indicates if free text is allowed. The response includes the created expense category details. """ create_expense_category_label_prompt = """ This tool creates a new expense category label in Extend. It takes the following arguments: - category_id (str): The ID of the expense category. - name (str): The name of the expense category label. - code (str): A unique code for the expense category label. - active (bool): The active status of the label (defaults to True). The response includes the created expense category label details. """ update_expense_category_prompt = """ This tool updates an existing expense category in Extend. It takes the following arguments: - category_id (str): The ID of the expense category to update. Optional arguments include: - name (Optional[str]): The new name for the expense category. - active (Optional[bool]): The updated active status. - required (Optional[bool]): The updated required status. - free_text_allowed (Optional[bool]): Indicates if free text is allowed. The response includes the updated expense category details. """ update_expense_category_label_prompt = """ This tool updates an existing expense category label in Extend. It takes the following arguments: - category_id (str): The ID of the expense category. - label_id (str): The ID of the expense category label to update. Optional arguments include: - name (Optional[str]): The new name for the label. - active (Optional[bool]): The updated active status of the label. The response includes the updated expense category label details. 
""" create_receipt_attachment_prompt = """ IMPORTANT: This does not require a transaction id to be passed in. Do not use one if the user does not specify a transaction id. This tool creates a receipt attachment in Extend by uploading a file via multipart form data. It takes the following arguments: - file_path (str): The file path for the receipt attachment image (the file should be accessible and in a supported format, e.g., PNG, JPEG, GIF, BMP, TIFF, HEIC, or PDF). - transaction_id (Optional[str]): The optional unique identifier of the transaction to attach the receipt to. The response is a JSON object containing details of the receipt attachment, including: - id: The unique identifier of the attachment. - contentType: The MIME type (e.g., 'image/png'). - urls: A dictionary with URLs for the original image, main image, and thumbnail. - createdAt and updatedAt timestamps. """ automatch_receipts_prompt = """ This tool initiates an asynchronous bulk receipt automatch job in Extend. It takes the following argument: - receipt_attachment_ids (List[str]): A list of receipt attachment IDs to be automatched. The response is a JSON object containing details of the automatch job, including: - id: The unique bulk job ID. - tasks: A list of task objects, each including the task ID, status, associated receipt attachment ID, matched transaction ID (if available), and the count of attachments. """ get_automatch_status_prompt = """ This tool retrieves the status of a bulk receipt automatch job in Extend. It takes the following argument: - job_id (str): The unique identifier of the automatch job whose status is to be retrieved. The response is a JSON object providing the current status and details of the job, including: - id: The job ID. - tasks: A list of task objects detailing each automatch operation, such as task ID, status, receipt attachment ID, matched transaction ID (if available), and attachments count. 
""" send_receipt_reminder_prompt = """ This tool sends a receipt reminder for a specific transaction in Extend. It takes the following argument: - transaction_id (str): The unique identifier of the transaction to send a receipt reminder for. The response is a 200 status code indicating that the reminder was sent successfully. If you receive a 429 response, it indicates that the user has already received a reminder for this transaction and only one can be sent out every 24 hours. This is useful for following up on missing receipts or encouraging users to submit receipts for transactions that require them. """ ``` -------------------------------------------------------------------------------- /extend_ai_toolkit/tests/test_integration.py: -------------------------------------------------------------------------------- ```python import asyncio import os import tempfile import uuid import httpx import pytest from dotenv import load_dotenv from extend import ExtendClient from extend_ai_toolkit.shared.functions import ( get_virtual_cards, get_credit_cards, create_expense_category, create_expense_category_label, get_expense_category_labels, update_expense_category_label, get_expense_categories, get_expense_category, update_expense_category, get_transactions, update_transaction_expense_data, create_receipt_attachment, automatch_receipts, send_receipt_reminder ) load_dotenv() @pytest.fixture(scope="session") def event_loop(): loop = asyncio.get_event_loop_policy().new_event_loop() yield loop loop.close() # Skip all tests if environment variables are not set pytestmark = pytest.mark.skipif( not all([ os.environ.get("EXTEND_API_KEY"), os.environ.get("EXTEND_API_SECRET"), os.environ.get("EXTEND_TEST_RECIPIENT"), os.environ.get("EXTEND_TEST_CARDHOLDER") ]), reason="Integration tests require EXTEND_API_KEY, EXTEND_API_SECRET, EXTEND_TEST_RECIPIENT, and EXTEND_TEST_CARDHOLDER environment variables" ) @pytest.fixture(scope="session") def extend(): """Create a real API client for 
integration testing""" api_key = os.environ.get("EXTEND_API_KEY") api_secret = os.environ.get("EXTEND_API_SECRET") return ExtendClient(api_key, api_secret) @pytest.fixture(scope="session") def test_recipient(): """Get the test recipient email""" return os.environ.get("EXTEND_TEST_RECIPIENT") @pytest.fixture(scope="session") def test_cardholder(): """Get the test cardholder email""" return os.environ.get("EXTEND_TEST_CARDHOLDER") _cached_test_credit_card = None @pytest.fixture(scope="session") def test_credit_card(extend, event_loop): """Synchronous fixture that caches the first active credit card for testing.""" global _cached_test_credit_card if _cached_test_credit_card is None: response = event_loop.run_until_complete( get_credit_cards(extend=extend, status="ACTIVE") ) assert response.get("creditCards"), "No credit cards available for testing" _cached_test_credit_card = response["creditCards"][0] return _cached_test_credit_card @pytest.mark.integration class TestCreditCards: """Integration tests for credit card functions""" @pytest.mark.asyncio async def test_list_credit_cards(self, extend): """Test listing credit cards with various filters""" response = await get_credit_cards(extend=extend) assert "creditCards" in response # List with pagination response = await get_credit_cards( extend=extend, page=1, per_page=10 ) assert len(response["creditCards"]) <= 10 # List with status filter response = await get_credit_cards( extend=extend, status="ACTIVE" ) for card in response["creditCards"]: assert card["status"] == "ACTIVE" @pytest.mark.integration class TestVirtualCards: """Integration tests for virtual card operations""" @pytest.mark.asyncio async def test_list_virtual_cards(self, extend): """Test listing virtual cards with various filters""" response = await get_virtual_cards(extend) assert "virtualCards" in response # List with pagination response = await get_virtual_cards( extend=extend, page=1, per_page=10 ) assert len(response["virtualCards"]) <= 10 # List 
with status filter response = await get_virtual_cards( extend=extend, status="CLOSED" ) for card in response["virtualCards"]: assert card["status"] == "CLOSED" @pytest.mark.integration class TestTransactions: """Integration tests for transaction operations""" @pytest.mark.asyncio async def test_list_transactions(self, extend): """Test listing transactions with various filters""" # Get transactions response = await get_transactions(extend) # Verify response structure assert isinstance(response, dict), "Response should be a dictionary" assert "report" in response, "Response should contain 'report' key" assert "transactions" in response["report"], "Report should contain 'transactions' key" assert isinstance(response["report"]["transactions"], list), "Transactions should be a list" @pytest.mark.asyncio async def test_update_transaction_expense_data(self, extend): """Test updating transaction expense data""" # Get a single transaction transactions_response = await get_transactions(extend, page=0, per_page=1, sort_field="createdAt") assert "transactions" in transactions_response["report"] transaction = transactions_response["report"]["transactions"][0] transaction_id = transaction["id"] # Update the transaction to have no expense categories data_no_expense_categories = { "expenseDetails": [] } response_no_expense_categories = await update_transaction_expense_data( extend, transaction_id, user_confirmed_data_values=True, data=data_no_expense_categories ) assert isinstance(response_no_expense_categories, dict), "Response should be a dictionary" assert response_no_expense_categories["id"] == transaction_id, "Transaction ID should match the input" assert "expenseCategories" not in response_no_expense_categories, "Expense categories should not exist on response" # Get an expense category and one of its labels expense_categories_response = await get_expense_categories(extend=extend, active=True) assert "expenseCategories" in expense_categories_response, "No expense categories 
found" expense_category = expense_categories_response["expenseCategories"][0] category_id = expense_category["id"] expense_category_labels_response = await get_expense_category_labels(extend, category_id=category_id) if not expense_category_labels_response.get("expenseLabels"): # Create a new label if none exist label_name = f"Test Label {str(uuid.uuid4())[:8]}" label_code = f"LBL{str(uuid.uuid4())[:8]}" expense_label_response = await create_expense_category_label( extend=extend, category_id=category_id, name=label_name, code=label_code, active=True ) expense_label = expense_label_response else: expense_label = expense_category_labels_response["expenseLabels"][0] # Update the transaction with an expense category and one of its labels data_with_expense_category = { "expenseDetails": [ { "categoryId": expense_category["id"], "labelId": expense_label["id"] } ] } response_with_expense_category = await update_transaction_expense_data( extend=extend, transaction_id=transaction_id, user_confirmed_data_values=True, data=data_with_expense_category ) assert isinstance(response_with_expense_category, dict), "Response should be a dictionary" assert response_with_expense_category["id"] == transaction_id, "Transaction ID should match the input" assert len(response_with_expense_category[ "expenseCategories"]) == 1, "Transaction should have only one expense category coding" coded_expense_category = response_with_expense_category["expenseCategories"][0] assert coded_expense_category["categoryId"] == expense_category[ "id"], "Expense categories should match the input" assert coded_expense_category["labelId"] == expense_label["id"], "Expense categories should match the input" @pytest.mark.integration class TestExpenseData: """Integration tests for expense category and label endpoints""" @pytest.mark.asyncio async def test_list_expense_categories(self, extend): """Test getting a list of expense categories""" response = await get_expense_categories(extend=extend) # Adjust the key based 
on your API's response structure assert isinstance(response, dict) # Example: if your response contains a key "expenseCategories" assert "expenseCategories" in response or "categories" in response @pytest.mark.asyncio async def test_create_and_get_expense_category(self, extend): """Test creating an expense category and then retrieving it""" # Create a new expense category with unique values category_name = f"Integration Test Category {str(uuid.uuid4())[:8]}" category_code = f"ITC{str(uuid.uuid4())[:8]}" create_response = await create_expense_category( extend=extend, name=category_name, code=category_code, required=True, active=True, free_text_allowed=False, ) category = create_response assert category, "Expense category creation failed" category_id = category["id"] # Retrieve the created category get_response = await get_expense_category(extend=extend, category_id=category_id) retrieved_category = get_response assert retrieved_category, "Expense category retrieval failed" assert retrieved_category["id"] == category_id @pytest.mark.asyncio async def test_update_expense_category(self, extend): """Test updating an expense category""" category_name = f"Integration Test Category {str(uuid.uuid4())[:8]}" category_code = f"ITC{str(uuid.uuid4())[:8]}" create_response = await create_expense_category( extend=extend, name=category_name, code=category_code, required=False, active=True, free_text_allowed=False, ) category = create_response category_id = category["id"] # Update the expense category new_name = f"Updated Category {str(uuid.uuid4())[:8]}" update_response = await update_expense_category( extend=extend, category_id=category_id, name=new_name, active=False, required=False, free_text_allowed=True, ) updated_category = update_response assert updated_category, "Expense category update failed" assert updated_category["name"] == new_name assert updated_category["active"] is False @pytest.mark.asyncio async def test_create_and_list_expense_category_labels(self, extend): 
"""Test creating an expense category label and listing labels for a category""" # Create a new expense category first category_name = f"Integration Test Category {str(uuid.uuid4())[:8]}" category_code = f"ITC{str(uuid.uuid4())[:8]}" create_cat_response = await create_expense_category( extend=extend, name=category_name, code=category_code, required=True, active=True, free_text_allowed=False, ) category = create_cat_response category_id = category["id"] # Create a new label for this expense category label_name = f"Label {str(uuid.uuid4())[:8]}" label_code = f"LBL{str(uuid.uuid4())[:8]}" create_label_response = await create_expense_category_label( extend=extend, category_id=category_id, name=label_name, code=label_code, active=True ) label = create_label_response assert label, "Expense category label creation failed" label_id = label["id"] # List labels for the expense category list_labels_response = await get_expense_category_labels( extend=extend, category_id=category_id, page=0, per_page=10 ) labels = list_labels_response.get("expenseLabels") assert labels is not None, "Expense category labels not found in response" # Verify that the newly created label is present in the list assert any(l["id"] == label_id for l in labels), "Created label not found in list" @pytest.mark.asyncio async def test_update_expense_category_label(self, extend): """Test updating an expense category label""" # Create a new expense category first category_name = f"Integration Test Category {str(uuid.uuid4())[:8]}" category_code = f"ITC{str(uuid.uuid4())[:8]}" create_cat_response = await create_expense_category( extend=extend, name=category_name, code=category_code, required=True, active=True, free_text_allowed=False, ) category = create_cat_response category_id = category["id"] # Create a new label for this category label_name = f"Label {str(uuid.uuid4())[:8]}" label_code = f"LBL{str(uuid.uuid4())[:8]}" create_label_response = await create_expense_category_label( extend=extend, 
category_id=category_id, name=label_name, code=label_code, active=True ) label = create_label_response label_id = label["id"] # Update the expense category label new_label_name = f"Updated Label {str(uuid.uuid4())[:8]}" update_label_response = await update_expense_category_label( extend=extend, category_id=category_id, label_id=label_id, name=new_label_name ) updated_label = update_label_response assert updated_label, "Expense category label update failed" assert updated_label["name"] == new_label_name @pytest.mark.integration class TestReceiptAttachments: """Integration tests for receipt attachment operations""" @pytest.mark.asyncio async def test_create_receipt_attachment(self, extend): """ Test creating a receipt attachment for a transaction by uploading a file. """ # Retrieve a transaction to attach the receipt to transactions_response = await get_transactions(extend, page=0, per_page=1) if not transactions_response.get("report", {}).get("transactions"): pytest.skip("No transactions available to attach receipt to") transaction_id = transactions_response["report"]["transactions"][0]["id"] # Create a temporary PNG file with minimal valid header bytes with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: # Write minimal PNG header bytes (this is not a complete image, # but may be sufficient for testing file upload endpoints) tmp.write(b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01' b'\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89') tmp.flush() tmp_name = tmp.name try: response = await create_receipt_attachment( extend=extend, transaction_id=transaction_id, file_path=tmp_name ) # Verify that the response contains expected receipt attachment fields. # Adjust these assertions based on your API response structure. 
assert response is not None, "Response should not be None" # Check for common fields in a successful response for field in ["id", "transactionId", "contentType", "urls", "createdAt", "uploadType"]: assert field in response, f"Missing expected field: {field}" # Initiate an automatch job automatch_response = await automatch_receipts( extend=extend, receipt_attachment_ids=[response["id"]] ) assert "id" in automatch_response, "Automatch response should include a job id" assert "tasks" in automatch_response, "Automatch response should include tasks" job_id = automatch_response["id"] # Retrieve the automatch job status using the new endpoint status_response = await extend.receipt_capture.get_automatch_status(job_id) assert "id" in status_response, "Status response should include a job id" assert status_response["id"] == job_id, "Job id should match the one returned during automatch" assert "tasks" in status_response, "Status response should include tasks" finally: # Clean up the temporary file os.remove(tmp_name) @pytest.mark.asyncio async def test_send_receipt_reminder(self, extend): """ Test sending a receipt reminder for a transaction. """ # Retrieve a transaction to send a reminder for transactions_response = await get_transactions(extend, page=0, per_page=10, receipt_missing=True) if not transactions_response.get("report", {}).get("transactions"): pytest.skip("No transactions available to send receipt reminder for") transaction_id = transactions_response["report"]["transactions"][0]["id"] try: # Send receipt reminder result = await send_receipt_reminder(extend, transaction_id) # The call should succeed and return None assert result is None except Exception as exc: # With exception chaining, check the cause. 
original_error = exc.__cause__ assert original_error is not None, "Expected a chained exception" if isinstance(original_error, httpx.HTTPStatusError): assert original_error.response.status_code == 429 else: raise AssertionError("Expected httpx.HTTPStatusError as the cause") from exc def test_environment_variables(): """Test that required environment variables are set""" assert os.environ.get("EXTEND_API_KEY"), "EXTEND_API_KEY environment variable is required" assert os.environ.get("EXTEND_API_SECRET"), "EXTEND_API_SECRET environment variable is required" ``` -------------------------------------------------------------------------------- /extend_ai_toolkit/shared/functions.py: -------------------------------------------------------------------------------- ```python import io import logging import os import uuid from datetime import datetime, timedelta from typing import Dict, Optional, List from extend import ExtendClient logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) pending_selections = {} # ========================= # Virtual Card Functions # ========================= async def get_virtual_cards( extend: ExtendClient, page: int = 0, per_page: int = 10, status: Optional[str] = None, recipient: Optional[str] = None, search_term: Optional[str] = None, sort_field: Optional[str] = None, sort_direction: Optional[str] = None, ) -> Dict: """Get list of virtual cards Args: page (int): The page number for pagination. Defaults to 0. per_page (int): The number of virtual cards to return per page. Defaults to 10. status (Optional[str]): Filter cards by status (e.g., "ACTIVE", "CANCELLED", "PENDING", "EXPIRED", "CLOSED", "CONSUMED") recipient (Optional[str], optional): A filter by recipient identifier. Defaults to None. search_term (Optional[str], optional): A term to search virtual cards by. Defaults to None. 
sort_field (Optional[str]): Field to sort by "createdAt", "updatedAt", "balanceCents", "displayName", "type", or "status" sort_direction (Optional[str]): Direction to sort (ASC or DESC) """ try: response = await extend.virtual_cards.get_virtual_cards( page=page, per_page=per_page, status=status.upper() if status else None, recipient=recipient, search_term=search_term, sort_field=sort_field, sort_direction=sort_direction ) return response except Exception as e: logger.error("Error getting virtual cards: %s", e) raise Exception("Error getting virtual cards: %s", e) async def get_virtual_card_detail(extend: ExtendClient, virtual_card_id: str) -> Dict: """Get details of a specific virtual card""" try: response = await extend.virtual_cards.get_virtual_card_detail(virtual_card_id) return response except Exception as e: logger.error("Error getting virtual card detail: %s", e) raise Exception(e) async def close_virtual_card(extend: ExtendClient, virtual_card_id: str) -> Dict: """Close a specific virtual card""" try: response = await extend.virtual_cards.close_virtual_card(virtual_card_id) return response except Exception as e: logger.error("Error closing virtual card: %s", e) raise Exception("Error closing virtual card") async def cancel_virtual_card(extend: ExtendClient, virtual_card_id: str) -> Dict: """Cancel a specific virtual card""" try: response = await extend.virtual_cards.cancel_virtual_card(virtual_card_id) return response except Exception as e: logger.error("Error canceling virtual card: %s", e) raise Exception("Error canceling virtual card") # ========================= # Transaction Functions # ========================= async def get_transactions( extend: ExtendClient, page: int = 0, per_page: int = 50, from_date: Optional[str] = None, to_date: Optional[str] = None, status: Optional[str] = None, virtual_card_id: Optional[str] = None, min_amount_cents: Optional[int] = None, max_amount_cents: Optional[int] = None, receipt_missing: Optional[bool] = None, 
search_term: Optional[str] = None, sort_field: Optional[str] = None, ) -> Dict: """ Get a list of recent transactions Args: page (int): pagination page number, per_page (int): number of transactions per page, from_date (Optional[str]): Start date (YYYY-MM-DD) to_date (Optional[str]): End date (YYYY-MM-DD) status (Optional[str]): Filter transactions by status (e.g., "PENDING", "CLEARED", "DECLINED", "NO_MATCH", "AVS_PASS", "AVS_FAIL", "AUTH_REVERSAL") virtual_card_id (Optional[str]): Filter by specific virtual card min_amount_cents (Optional[int]): Minimum amount in cents max_amount_cents (Optional[int]): Maximum amount in cents receipt_missing (Optional[bool]): Filter transactions by whether they are missing a receipt search_term (Optional[str]): Filter transactions by search term (e.g., "Subscription") sort_field (Optional[str]): Field to sort by, with optional direction Use "recipientName", "merchantName", "amount", "date" for ASC Use "-recipientName", "-merchantName", "-amount", "-date" for DESC """ try: response = await extend.transactions.get_transactions( page=page, per_page=per_page, from_date=from_date, to_date=to_date, status=status.upper() if status else None, virtual_card_id=virtual_card_id, min_amount_cents=min_amount_cents, max_amount_cents=max_amount_cents, search_term=search_term, sort_field=sort_field, receipt_missing=receipt_missing, ) return response except Exception as e: logger.error("Error getting transactions: %s", e) raise Exception("Error getting transactions") async def get_transaction_detail(extend: ExtendClient, transaction_id: str) -> Dict: """Get a transaction detail""" try: response = await extend.transactions.get_transaction(transaction_id) return response except Exception as e: logger.error("Error getting transaction detail: %s", e) raise Exception("Error getting transaction detail") # ========================= # Credit Card Functions # ========================= async def get_credit_cards( extend: ExtendClient, page: int = 0, 
per_page: int = 10, status: Optional[str] = None, search_term: Optional[str] = None, sort_direction: Optional[str] = None, ) -> Dict: """Get a list of credit cards""" try: response = await extend.credit_cards.get_credit_cards( page=page, per_page=per_page, status=status.upper() if status else None, search_term=search_term, sort_direction=sort_direction, ) return response except Exception as e: logger.error("Error getting credit cards: %s", e) raise Exception("Error getting credit cards") async def get_credit_card_detail(extend: ExtendClient, credit_card_id: str) -> Dict: """Get details of a specific credit card""" try: response = await extend.virtual_cards.get_credit_card_detail(credit_card_id) return response except Exception as e: logger.error("Error getting credit card details: %s", e) raise Exception(e) # ========================= # Expense Data Functions # ========================= async def get_expense_categories( extend: ExtendClient, active: Optional[bool] = None, required: Optional[bool] = None, search: Optional[str] = None, sort_field: Optional[str] = None, sort_direction: Optional[str] = None, ) -> Dict: """ Get a list of expense categories. """ try: response = await extend.expense_data.get_expense_categories( active=active, required=required, search=search, sort_field=sort_field, sort_direction=sort_direction, ) return response except Exception as e: logger.error("Error getting expense categories: %s", e) raise Exception("Error getting expense categories: %s", e) async def get_expense_category(extend: ExtendClient, category_id: str) -> Dict: """ Get detailed information about a specific expense category. 
""" try: response = await extend.expense_data.get_expense_category(category_id) return response except Exception as e: logger.error("Error getting expense category: %s", e) raise Exception("Error getting expense category: %s", e) async def get_expense_category_labels( extend: ExtendClient, category_id: str, page: Optional[int] = None, per_page: Optional[int] = None, active: Optional[bool] = None, search: Optional[str] = None, sort_field: Optional[str] = None, sort_direction: Optional[str] = None, ) -> Dict: """ Get a paginated list of expense category labels. """ try: response = await extend.expense_data.get_expense_category_labels( category_id=category_id, page=page, per_page=per_page, active=active, search=search, sort_field=sort_field, sort_direction=sort_direction, ) return response except Exception as e: logger.error("Error getting expense category labels: %s", e) raise Exception("Error getting expense category labels: %s", e) async def create_expense_category( extend: ExtendClient, name: str, code: str, required: bool, active: Optional[bool] = None, free_text_allowed: Optional[bool] = None, ) -> Dict: """ Create an expense category. """ try: response = await extend.expense_data.create_expense_category( name=name, code=code, required=required, active=active, free_text_allowed=free_text_allowed, ) return response except Exception as e: logger.error("Error creating expense category: %s", e) raise Exception("Error creating expense category: %s", e) async def create_expense_category_label( extend: ExtendClient, category_id: str, name: str, code: str, active: bool = True ) -> Dict: """ Create an expense category label. 
""" try: response = await extend.expense_data.create_expense_category_label( category_id=category_id, name=name, code=code, active=active ) return response except Exception as e: logger.error("Error creating expense category label: %s", e) raise Exception("Error creating expense category label: %s", e) async def update_expense_category( extend: ExtendClient, category_id: str, name: Optional[str] = None, active: Optional[bool] = None, required: Optional[bool] = None, free_text_allowed: Optional[bool] = None, ) -> Dict: """ Update an expense category. """ try: response = await extend.expense_data.update_expense_category( category_id=category_id, name=name, active=active, required=required, free_text_allowed=free_text_allowed, ) return response except Exception as e: logger.error("Error updating expense category: %s", e) raise Exception("Error updating expense category: %s", e) async def update_expense_category_label( extend: ExtendClient, category_id: str, label_id: str, name: Optional[str] = None, active: Optional[bool] = None, ) -> Dict: """ Update an expense category label. """ try: response = await extend.expense_data.update_expense_category_label( category_id=category_id, label_id=label_id, name=name, active=active, ) return response except Exception as e: logger.error("Error updating expense category label: %s", e) raise Exception("Error updating expense category label: %s", e) async def propose_transaction_expense_data( extend: ExtendClient, transaction_id: str, data: Dict ) -> Dict: """ Propose expense data changes for a transaction without applying them. 
Args: extend: The Extend client instance transaction_id: The unique identifier of the transaction data: A dictionary representing the expense data to update Returns: Dict: A confirmation request with token and expiration """ # Fetch transaction to ensure it exists transaction = await extend.transactions.get_transaction(transaction_id) # Generate a unique confirmation token confirmation_token = str(uuid.uuid4()) # Set expiration time (10 minutes from now) expiration_time = datetime.now() + timedelta(minutes=10) # Store the pending selection with its metadata pending_selections[confirmation_token] = { "transaction_id": transaction_id, "data": data, "created_at": datetime.now().isoformat(), "expires_at": expiration_time.isoformat(), "status": "pending" } # Return the confirmation request return { "status": "pending_confirmation", "transaction_id": transaction_id, "confirmation_token": confirmation_token, "expires_at": expiration_time.isoformat(), "proposed_categories": [ {"categoryId": category.get("categoryId", "Unknown"), "labelId": category.get("labelId", "None")} for category in data.get("expenseDetails", []) ] } async def confirm_transaction_expense_data( extend: ExtendClient, confirmation_token: str ) -> Dict: """ Confirm and apply previously proposed expense data changes. 
Args: extend: The Extend client instance confirmation_token: The unique token from the proposal step Returns: Dict: The updated transaction details """ # Check if token exists if confirmation_token not in pending_selections: raise Exception("Invalid confirmation token") # Get the pending selection selection = pending_selections[confirmation_token] # Check if expired if datetime.now() > datetime.fromisoformat(selection["expires_at"]): # Clean up expired token del pending_selections[confirmation_token] raise Exception("Confirmation token has expired") # Apply the expense data update response = await extend.transactions.update_transaction_expense_data( selection["transaction_id"], selection["data"] ) # Mark as confirmed and clean up selection["status"] = "confirmed" selection["confirmed_at"] = datetime.now().isoformat() # In a real implementation, you might want to keep the record for auditing # but for simplicity, we'll delete it here del pending_selections[confirmation_token] return response async def update_transaction_expense_data( extend: ExtendClient, transaction_id: str, user_confirmed_data_values: bool, data: Dict ) -> Dict: """ Internal function to update the expense data for a specific transaction. This should not be exposed directly to external callers. 
Args: extend: The Extend client instance transaction_id: The unique identifier of the transaction user_confirmed_data_values: Only true if the user has confirmed the specific values in the data argument data: A dictionary representing the expense data to update Returns: Dict: A dictionary containing the updated transaction details """ try: if not user_confirmed_data_values: raise Exception(f"User has not confirmed the expense category or label values") response = await extend.transactions.update_transaction_expense_data(transaction_id, data) return response except Exception as e: raise Exception(f"Error updating transaction expense data: {str(e)}") # ========================= # Receipt Attachment Functions # ========================= async def create_receipt_attachment( extend: ExtendClient, transaction_id: str, file_path: str, ) -> Dict: """ Create a receipt attachment by uploading a file via multipart form data. Args: extend: The Extend client instance transaction_id (str): The unique identifier of the transaction to attach the receipt to. file_path (str): A file path for the receipt image. Returns: Dict: A dictionary representing the receipt attachment details, including: - id: Unique identifier of the receipt attachment. - transactionId: The associated transaction ID. - contentType: The MIME type of the uploaded file. - urls: A dictionary with URLs for the original image, main image, and thumbnail. - createdAt: Timestamp when the receipt attachment was created. - uploadType: A string describing the type of upload (e.g., "TRANSACTION", "VIRTUAL_CARD"). 
""" try: with open(file_path, 'rb') as f: file_content = f.read() file_obj = io.BytesIO(file_content) # Get the filename and determine the MIME type filename = os.path.basename(file_path) mime_type = None # Set the MIME type based on file extension if filename.lower().endswith('.png'): mime_type = 'image/png' elif filename.lower().endswith('.jpg') or filename.lower().endswith('.jpeg'): mime_type = 'image/jpeg' elif filename.lower().endswith('.gif'): mime_type = 'image/gif' elif filename.lower().endswith('.bmp'): mime_type = 'image/bmp' elif filename.lower().endswith('.tif') or filename.lower().endswith('.tiff'): mime_type = 'image/tiff' elif filename.lower().endswith('.heic'): mime_type = 'image/heic' elif filename.lower().endswith('.pdf'): mime_type = 'application/pdf' else: raise ValueError(f"Unsupported file type: {filename}") file_obj = io.BytesIO(file_content) file_obj.name = filename file_obj.content_type = mime_type response = await extend.receipt_attachments.create_receipt_attachment( transaction_id=transaction_id, file=file_obj ) return response except Exception as e: logger.error("Error creating receipt attachment: %s", e) raise Exception("Error creating receipt attachment: %s", e) # ========================= # Receipt Capture Functions # ========================= async def automatch_receipts( extend: ExtendClient, receipt_attachment_ids: List[str], ) -> Dict: """ Initiates an asynchronous bulk receipt automatch job. This method triggers an asynchronous job on the server that processes the provided receipt attachment IDs. The operation is non-blocking: it immediately returns a job ID and preliminary details, while the matching process is performed in the background. Args: receipt_attachment_ids (List[str]): A list of receipt attachment IDs to be automatched. Returns: Dict: A dictionary representing the Bulk Receipt Automatch Response. 
""" try: response = await extend.receipt_capture.automatch_receipts( receipt_attachment_ids=receipt_attachment_ids ) return response except Exception as e: logger.error("Error initiating receipt automatch: %s", e) raise Exception("Error initiating receipt automatch: %s", e) async def get_automatch_status( extend: ExtendClient, job_id: str, ) -> Dict: """ Retrieves the status of a bulk receipt capture automatch job. Args: job_id (str): The ID of the automatch job whose status is to be retrieved. Returns: Dict: A dictionary representing the current Bulk Receipt Automatch Response. """ try: response = await extend.receipt_capture.get_automatch_status(job_id=job_id) return response except Exception as e: logger.error("Error getting automatch status: %s", e) raise Exception("Error getting automatch status: %s", e) async def send_receipt_reminder( extend: ExtendClient, transaction_id: str, ) -> Dict: """ Send a transaction-specific receipt reminder. Args: extend: The Extend client instance transaction_id (str): The unique identifier of the transaction. Returns: Dict: Response from the API indicating the reminder was sent successfully. """ try: response = await extend.transactions.send_receipt_reminder(transaction_id) return response except Exception as e: logger.error("Error sending receipt reminder: %s", e) raise Exception(f"Error sending receipt reminder: {e}") from e # Optional: Cleanup function to remove expired selections async def cleanup_pending_selections(): """Remove all expired selection tokens""" now = datetime.now() expired_tokens = [ token for token, selection in pending_selections.items() if now > datetime.fromisoformat(selection["expires_at"]) ] for token in expired_tokens: del pending_selections[token] ```