This is page 1 of 3. Use http://codebase.md/ilikepizza2/qa-mcp?page={x} to view the full context.
# Directory Structure
```
├── .env.example
├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── main.py
├── mcp_server.py
├── README.md
├── requirements.txt
├── src
│   ├── __init__.py
│   ├── agents
│   │   ├── __init__.py
│   │   ├── auth_agent.py
│   │   ├── crawler_agent.py
│   │   ├── js_utils
│   │   │   └── xpathgenerator.js
│   │   └── recorder_agent.py
│   ├── browser
│   │   ├── __init__.py
│   │   ├── browser_controller.py
│   │   └── panel
│   │       └── panel.py
│   ├── core
│   │   ├── __init__.py
│   │   └── task_manager.py
│   ├── dom
│   │   ├── buildDomTree.js
│   │   ├── history
│   │   │   ├── service.py
│   │   │   └── view.py
│   │   ├── service.py
│   │   └── views.py
│   ├── execution
│   │   ├── __init__.py
│   │   └── executor.py
│   ├── llm
│   │   ├── __init__.py
│   │   ├── clients
│   │   │   ├── azure_openai_client.py
│   │   │   ├── gemini_client.py
│   │   │   └── openai_client.py
│   │   └── llm_client.py
│   ├── security
│   │   ├── __init__.py
│   │   ├── nuclei_scanner.py
│   │   ├── semgrep_scanner.py
│   │   ├── utils.py
│   │   └── zap_scanner.py
│   └── utils
│       ├── __init__.py
│       ├── image_utils.py
│       └── utils.py
└── test_schema.md
```
# Files
--------------------------------------------------------------------------------
/.env.example:
--------------------------------------------------------------------------------
```
LLM_API_KEY="YOUR_LLM_API_KEY"
LLM_BASE_URL="LLM_BASE_URL"
LLM_MODEL="LLM_MODEL"
```
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
```
.env
__pycache__/
*.pyc
/venv/
/.venv
/output
ignore-*
.DS_Store
doc
stitched.png
visual_baselines/
/results/
```
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
```markdown
# VibeShift: The Security Engineer for Vibe Coders
**VibeShift** is an intelligent security agent designed to integrate seamlessly with AI coding assistants (like Cursor, GitHub Copilot, Claude Code, etc.). It acts as your automated security engineer, analyzing code generated by AI, identifying vulnerabilities, and facilitating AI-driven remediation *before* insecure code makes it to your codebase. It leverages the **MCP (Model Context Protocol)** for smooth interaction within your existing AI coding environment.
<a href="https://www.producthunt.com/posts/vibeshift-mcp?embed=true&utm_source=badge-featured&utm_medium=badge&utm_source=badge-vibeshift-mcp" target="_blank"><img src="https://api.producthunt.com/widgets/embed-image/v1/featured.svg?post_id=966186&theme=light&t=1747654611925" alt="VibeShift MCP - Get secure, working code in 1 shot | Product Hunt" style="width: 115px; height: 25px;" width="250" height="54" /></a>
[](https://x.com/Omiiee_Chan)
[](https://x.com/_gauravkabra_)

**The Problem:** AI coding assistants accelerate development dramatically, but they can also generate code with subtle or overt security vulnerabilities. Manually reviewing all AI-generated code for security flaws is slow, error-prone, and doesn't scale with the speed of AI development. This "vibe-driven development" can leave applications exposed.
**The Solution: GroundNG's VibeShift** bridges this critical security gap by enabling your AI coding assistant to:
1. **Automatically Analyze AI-Generated Code:** As code is generated or modified by an AI assistant, VibeShift can be triggered to perform security analysis using a suite of tools (SAST, DAST components) and AI-driven checks.
2. **Identify Security Vulnerabilities:** Pinpoints common and complex vulnerabilities (e.g., XSS, SQLi, insecure configurations, logic flaws) within the AI-generated snippets or larger code blocks.
3. **Facilitate AI-Driven Remediation:** Provides detailed feedback and vulnerability information directly to the AI coding assistant, enabling it to suggest or even automatically apply fixes.
4. **Create a Security Feedback Loop:** Ensures that developers and their AI assistants are immediately aware of potential security risks, allowing for rapid correction and learning.
This creates a "shift-left" security paradigm for AI-assisted coding, embedding security directly into the development workflow and helping to ship more secure code, faster.
# Demo (Click to play these videos)
[](https://www.youtube.com/watch?v=bN_RgQGa8B0)
[](https://youtu.be/wCbCUCqjnXQ)
## Features
* **MCP Integration:** Seamlessly integrates with Cursor/Windsurf/Github Copilot/Roo Code
* **Automated Security Scanning:** Triggers on AI code generation/modification to perform:
  * **Static Code Analysis (SAST):** Integrates tools like Semgrep to find vulnerabilities in source code.
  * **Dynamic Analysis (DAST Primitives):** Can invoke tools like Nuclei or ZAP for checks against running components (where applicable).
* **AI-Assisted Test Recording:** Generate Playwright-based test scripts from natural language descriptions (in automated mode).
* **Deterministic Test Execution:** Run recorded JSON test files reliably using Playwright.
* **AI-Powered Test Discovery:** Crawl websites and leverage any LLM (via an OpenAI-compatible API) to suggest test steps for discovered pages.
* **Regression Testing:** Easily run existing test suites to catch regressions.
* **Automated Feedback Loop:** Execution results (including failures, screenshots, console logs) are returned, providing direct feedback to the AI assistant.
* **Self Healing:** Existing tests self-heal when the application code changes, so there is no need to update them manually.
* **UI tests:** Supports UI checks that Playwright cannot express directly, for example `Check if the text is overflowing in the div`.
* **Visual Regression Testing**: Combines a traditional pixelmatch comparison with a vision-LLM approach.
## How it Works
```
+-------------+       +-----------------+       +---------------------+       +------------------+       +--------------+
|    User     | ----> | AI Coding Agent | ----> |     MCP Server      | ----> | Scan, test, exec | ----> |   Browser    |
| (Developer) |       | (e.g., Copilot) |       |   (mcp_server.py)   |       |  (SAST, Record)  |       | (Playwright) |
+-------------+       +-----------------+       +---------------------+       +------------------+       +--------------+
       ^                                                   |                           |                         |
       |---------------------------------------------------+---------------------------+-------------------------+
                                              [Test Results / Feedback]
```
1. **User:** Prompts their AI coding assistant (e.g., "Test this repository for security vulnerabilities", "Record a test for the login flow", "Run the regression test 'test_login.json'").
2. **AI Coding Agent:** Recognizes the intent and uses MCP to call the appropriate tool provided by the `MCP Server`.
3. **MCP Server:** Routes the request to the corresponding function (`get_security_scan`, `record_test_flow`, `run_regression_test`, `discover_test_flows`, `list_recorded_tests`).
4. **VibeShift Agent:**
   * **Traditional Security Scan:** Invokes **Static Analysis Tools** (e.g., Semgrep) on the code.
   * **Recording:** The `WebAgent` (in automated mode) interacts with the LLM to plan steps, controls the browser via `BrowserController` (Playwright), processes HTML/Vision, and saves the resulting test steps to a JSON file in the `output/` directory.
   * **Execution:** The `TestExecutor` loads the specified JSON test file, uses `BrowserController` to interact with the browser according to the recorded steps, and captures results, screenshots, and console logs.
   * **Discovery:** The `CrawlerAgent` uses `BrowserController` and `LLMClient` to crawl pages and suggest test steps.
5. **Browser:** Playwright drives the actual browser interaction.
6. **Feedback Loop:**
   * The comprehensive security report (vulnerabilities, locations, suggestions) is returned through the MCP server to the **AI Coding Agent**.
   * The AI Coding Agent presents this to the developer and can use the information to **suggest or apply fixes**.
   * The goal is a rapid cycle of code generation -> security scan -> AI-driven fix -> re-scan (optional).
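For orientation, here is a minimal sketch of how an MCP client could call these tools programmatically using the Python MCP SDK's stdio client. It launches the server the same way as the config shown under "Adding the MCP Server"; the tool names match the list above, but the (empty) argument set is illustrative, not the server's exact schema.
```python
# Illustrative sketch only: argument names/values are assumptions, not the server's exact schema.
import asyncio
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

async def main():
    # Launch mcp_server.py the same way the MCP config below does.
    server = StdioServerParameters(
        command="uv",
        args=["--directory", "path/to/cloned_repo", "run", "mcp_server.py"],
    )
    async with stdio_client(server) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            tools = await session.list_tools()  # record_test_flow, run_regression_test, ...
            print([t.name for t in tools.tools])
            result = await session.call_tool("list_recorded_tests", arguments={})
            print(result.content)

asyncio.run(main())
```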
## Getting Started
### Prerequisites
* Python 3.10+
* Access to any LLM (gemini 2.0 flash works best for free in my testing)
* MCP installed (`pip install mcp[cli]`)
* Playwright browsers installed (`patchright install`)
### Installation
1. **Clone the repository:**
```bash
git clone https://github.com/GroundNG/VibeShift
cd VibeShift
```
2. **Create a virtual environment (recommended):**
```bash
python -m venv venv
source venv/bin/activate # Linux/macOS
# venv\Scripts\activate # Windows
```
3. **Install dependencies:**
```bash
pip install -r requirements.txt
```
4. **Install Playwright browsers:**
```bash
patchright install --with-deps # Installs browsers and OS dependencies
```
### Configuration
1. Rename `.env.example` to `.env` in the project root directory.
2. Add your LLM API key and other necessary details:
```dotenv
# .env
LLM_API_KEY="YOUR_LLM_API_KEY"
```
* Replace `YOUR_LLM_API_KEY` with your actual key.
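For reference, the loader helpers in `src/utils/utils.py` consume these values roughly like this (minimal sketch; the optional `LLM_BASE_URL` and `LLM_MODEL` entries from `.env.example` are read the same way):
```python
# Minimal sketch of how the .env values are read (see src/utils/utils.py for the actual helpers).
import os
from dotenv import load_dotenv

load_dotenv()  # picks up .env from the project root
api_key = os.getenv("LLM_API_KEY")
if not api_key:
    raise ValueError("LLM_API_KEY not found in .env file or environment variables.")
base_url = os.getenv("LLM_BASE_URL")  # optional: OpenAI-compatible endpoint
model = os.getenv("LLM_MODEL")        # optional: model name to use
```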
### Adding the MCP Server
Add this to your MCP config:
```json
{
  "mcpServers": {
    "VibeShift": {
      "command": "uv",
      "args": ["--directory", "path/to/cloned_repo", "run", "mcp_server.py"]
    }
  }
}
```
Keep this server running while you interact with your AI coding assistant.
## Usage
Interact with the agent through your MCP-enabled AI coding assistant using natural language.
**Examples:**
* **Security Analysis:**
  * **Automatic (Preferred):** VibeShift automatically analyzes code snippets generated or significantly modified by the AI assistant.
  * **Explicit Commands:**
    > "VibeShift, analyze this function for security vulnerabilities."
    > "Ask VibeShift to check the Python code Copilot just wrote for SQL injection."
    > "Secure the generated code with VibeShift before committing."
* **Record a Test:**
  > "Record a test: go to https://practicetestautomation.com/practice-test-login/, type 'student' into the username field, type 'Password123' into the password field, click the submit button, and verify the text 'Congratulations student' is visible."
  * *(The agent will perform these actions automatically and save a `test_....json` file in `output/`)*
* **Execute a Test:**
  > "Run the regression test `output/test_practice_test_login_20231105_103000.json`"
  * *(The agent will execute the steps in the specified file and report PASS/FAIL status with errors and details.)*
* **Discover Test Steps:**
  > "Discover potential test steps starting from https://practicetestautomation.com/practice/"
  * *(The agent will crawl the site, analyze pages, and return suggested test steps for each.)*
* **List Recorded Tests:**
  > "List the available recorded web tests."
  * *(The agent will return a list of `.json` files found in the `output/` directory.)*
**Output:**
* **Security Reports:** Returned to the AI coding assistant, detailing:
  * Vulnerability type (e.g., CWE, OWASP category)
  * Location in code
  * Severity
  * Evidence / Explanation
  * Suggested remediations (often for the AI to action)
* **Recorded Tests:** Saved as JSON files in the `output/` directory (see `test_schema.md` for format).
* **Execution Results:** Returned as a JSON object summarizing the run (status, errors, evidence paths). Full results are also saved to `output/execution_result_....json`.
* **Discovery Results:** Returned as a JSON object with discovered URLs and suggested steps. Full results saved to `output/discovery_results_....json`.
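Because execution results are plain JSON on disk, they can also be inspected outside the MCP loop. A hedged sketch (only `status` and `errors` are taken from the description above; other fields may differ):
```python
# Hedged sketch: inspect the most recent saved execution result.
import json
from pathlib import Path

latest = max(Path("output").glob("execution_result_*.json"), key=lambda p: p.stat().st_mtime)
result = json.loads(latest.read_text())
print(f"{latest.name}: status={result.get('status')}")
for err in result.get("errors", []) or []:
    print("  error:", err)
```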
## Inspiration
* **[Browser Use](https://github.com/browser-use/browser-use/)**: The DOM context tree generation is heavily inspired by theirs and modified to accommodate static/dynamic/visual elements. Special thanks to them for their contribution to open source.
* **[Semgrep](https://github.com/returntocorp/semgrep)**: A powerful open-source static analysis tool we leverage.
* **[Nuclei](https://github.com/projectdiscovery/nuclei)**: For template-based dynamic scanning capabilities.
## Contributing
We welcome contributions! Please see `CONTRIBUTING.md` for details on how to get started, report issues, and submit pull requests. We're particularly interested in:
* New security analyzer integrations.
## License
This project is licensed under the [Apache-2.0 License](LICENSE).
```
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
```markdown
# Contributing to the AI Web Testing Agent
First off, thank you for considering contributing! This project aims to improve the development workflow by integrating automated web testing directly with AI coding assistants. Your contributions can make a real difference.
## Code of Conduct
This project and everyone participating in it is governed by the [Code of Conduct](CODE_OF_CONDUCT.md). By participating, you are expected to uphold this code. Please report unacceptable behavior.
## How Can I Contribute?
There are many ways to contribute, from reporting bugs to implementing new features.
### Reporting Bugs
* Ensure the bug was not already reported by searching on GitHub under [Issues](https://github.com/Ilikepizza2/GroundNG/issues).
* If you're unable to find an open issue addressing the problem, [open a new one](https://github.com/Ilikepizza2/GroundNG/issues/new). Be sure to include a **title and clear description**, as much relevant information as possible, and a **code sample or an executable test case** demonstrating the expected behavior that is not occurring.
* Include details about your environment (OS, Python version, library versions).
### Suggesting Enhancements
* Open a new issue to suggest an enhancement. Provide a clear description of the enhancement and its potential benefits.
* Explain why this enhancement would be useful and provide examples if possible.
### Pull Requests
1. **Fork the repository** on GitHub.
2. **Clone your fork** locally: `git clone git@github.com:Ilikepizza2/GroundNG.git`
3. **Create a virtual environment** and install dependencies:
```bash
cd <repository-name>
python -m venv venv
source venv/bin/activate # Or venv\Scripts\activate on Windows
pip install -r requirements.txt
playwright install --with-deps
```
4. **Create a topic branch** for your changes: `git checkout -b feature/your-feature-name` or `git checkout -b fix/your-bug-fix`.
5. **Make your changes.** Write clean, readable code. Add comments where necessary.
6. **Add tests** for your changes. Ensure existing tests pass. (See Testing section below).
7. **Format your code** (e.g., using Black): `black .`
8. **Commit your changes** using a descriptive commit message. Consider using [Conventional Commits](https://www.conventionalcommits.org/).
9. **Push your branch** to your fork on GitHub: `git push origin feature/your-feature-name`.
10. **Open a Pull Request** to the `main` branch of the original repository. Provide a clear description of your changes and link any relevant issues.
## Development Setup
* Follow the installation steps in the [README.md](README.md).
* Ensure you have a `.env` file set up with your LLM API key for running the agent components that require it.
* Use the `mcp dev mcp_server.py` command to run the server locally for testing MCP interactions.
## Testing
* This project uses `pytest`. Run tests using:
```bash
pytest
```
* Please add tests for any new features or bug fixes. Place tests in a `tests/` directory (if not already present).
* Ensure all tests pass before submitting a pull request.
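For example, a minimal test for the environment-variable helpers in `src/utils/utils.py` might look like this (the file location and import path are assumptions about the final layout):
```python
# tests/test_env_helpers.py -- hedged example; adjust the import path to the project layout.
from src.utils.utils import load_api_key, load_llm_model

def test_load_api_key_reads_env(monkeypatch):
    monkeypatch.setenv("LLM_API_KEY", "dummy-key")
    assert load_api_key() == "dummy-key"

def test_load_llm_model_reads_env(monkeypatch):
    monkeypatch.setenv("LLM_MODEL", "gemini-2.0-flash")
    assert load_llm_model() == "gemini-2.0-flash"
```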
## Code Style
* Please follow PEP 8 guidelines.
* We recommend using [Black](https://github.com/psf/black) for code formatting. Run `black .` before committing.
* Use clear and descriptive variable and function names.
* Add docstrings to modules, classes, and functions.
## Questions?
If you have questions about contributing or the project in general, feel free to open an issue on GitHub.
Thank you for your contribution!
```
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
```markdown
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at no one available :(
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.1, available at
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder][mozilla coc].
For answers to common questions about this code of conduct, see the FAQ at
[https://www.contributor-covenant.org/faq][faq]. Translations are available at
[https://www.contributor-covenant.org/translations][translations].
[homepage]: https://www.contributor-covenant.org
[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
[mozilla coc]: https://github.com/mozilla/diversity
[faq]: https://www.contributor-covenant.org/faq
[translations]: https://www.contributor-covenant.org/translations
```
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
```python
```
--------------------------------------------------------------------------------
/src/agents/__init__.py:
--------------------------------------------------------------------------------
```python
```
--------------------------------------------------------------------------------
/src/browser/__init__.py:
--------------------------------------------------------------------------------
```python
```
--------------------------------------------------------------------------------
/src/core/__init__.py:
--------------------------------------------------------------------------------
```python
```
--------------------------------------------------------------------------------
/src/execution/__init__.py:
--------------------------------------------------------------------------------
```python
```
--------------------------------------------------------------------------------
/src/llm/__init__.py:
--------------------------------------------------------------------------------
```python
```
--------------------------------------------------------------------------------
/src/security/__init__.py:
--------------------------------------------------------------------------------
```python
```
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
```python
```
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
```
playwright
google-genai
beautifulsoup4
python-dotenv
Pillow
pydantic>=2.0
pytest
pytest-html
mcp[cli]
pixelmatch>=0.3.0
openai
patchright
requests
semgrep
```
--------------------------------------------------------------------------------
/src/utils/utils.py:
--------------------------------------------------------------------------------
```python
# /src/utils/utils.py
import os
from dotenv import load_dotenv
def load_api_key():
    """Loads the llm API key from .env file."""
    load_dotenv()
    api_key = os.getenv("LLM_API_KEY")
    if not api_key:
        raise ValueError("LLM_API_KEY not found in .env file or environment variables.")
    return api_key

def load_api_base_url():
    """Loads the API base url from .env file."""
    load_dotenv()
    base_url = os.getenv("LLM_BASE_URL")
    if not base_url:
        raise ValueError("LLM_BASE_URL not found in .env file or environment variables.")
    return base_url

def load_api_version():
    """Loads the API Version from .env file."""
    load_dotenv()
    api_version = os.getenv("LLM_API_VERSION")
    if not api_version:
        raise ValueError("LLM_API_VERSION not found in .env file or environment variables.")
    return api_version

def load_llm_model():
    """Loads the llm model from .env file."""
    load_dotenv()
    llm_model = os.getenv("LLM_MODEL")
    if not llm_model:
        raise ValueError("LLM_MODEL not found in .env file or environment variables.")
    return llm_model

def load_llm_timeout():
    """Loads the default llm model timeout from .env file."""
    load_dotenv()
    llm_timeout = os.getenv("LLM_TIMEOUT")
    if not llm_timeout:
        raise ValueError("LLM_TIMEOUT not found in .env file or environment variables.")
    return llm_timeout
```
--------------------------------------------------------------------------------
/src/agents/js_utils/xpathgenerator.js:
--------------------------------------------------------------------------------
```javascript
function generateXPathForElement(currentElement) {
  function getElementPosition(currentElement) {
    if (!currentElement.parentElement) return 0;
    const tagName = currentElement.nodeName.toLowerCase();
    const siblings = Array.from(currentElement.parentElement.children)
      .filter((sib) => sib.nodeName.toLowerCase() === tagName);
    if (siblings.length === 1) return 0;
    const index = siblings.indexOf(currentElement) + 1;
    return index;
  }

  const segments = [];
  let elementToProcess = currentElement;
  while (elementToProcess && elementToProcess.nodeType === Node.ELEMENT_NODE) {
    const position = getElementPosition(elementToProcess);
    const tagName = elementToProcess.nodeName.toLowerCase();
    const xpathIndex = position > 0 ? `[${position}]` : "";
    segments.unshift(`${tagName}${xpathIndex}`);
    const parentNode = elementToProcess.parentNode;
    if (!parentNode || parentNode.nodeType !== Node.ELEMENT_NODE) {
      elementToProcess = null;
    } else if (parentNode instanceof ShadowRoot || parentNode instanceof HTMLIFrameElement) {
      elementToProcess = null;
    } else {
      elementToProcess = parentNode;
    }
  }

  let finalPath = segments.join("/");
  if (finalPath && !finalPath.startsWith('html') && !finalPath.startsWith('/html')) {
    if (finalPath.startsWith('body')) {
      finalPath = '/html/' + finalPath;
    } else if (!finalPath.startsWith('/')) {
      finalPath = '/' + finalPath;
    }
  } else if (finalPath.startsWith('body')) {
    finalPath = '/html/' + finalPath;
  }
  return finalPath || null;
}
```
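A hedged sketch of how a helper like this is typically injected and called from Python with Playwright; the project's real injection path lives in the browser/DOM modules, so treat the wrapper below as illustrative only.
```python
# Hedged sketch: evaluate the XPath helper against an element handle via Playwright.
from pathlib import Path
from playwright.sync_api import sync_playwright

js_source = Path("src/agents/js_utils/xpathgenerator.js").read_text()

with sync_playwright() as p:
    page = p.chromium.launch(headless=True).new_page()
    page.goto("https://example.com")
    handle = page.query_selector("h1")
    # Define the helper in the page's realm, then call it on the element handle.
    xpath = handle.evaluate(f"(el) => {{ {js_source}; return generateXPathForElement(el); }}")
    print(xpath)  # e.g. "/html/body/div/h1"
```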
--------------------------------------------------------------------------------
/src/security/utils.py:
--------------------------------------------------------------------------------
```python
# utils.py
import logging
import json
import os
from datetime import datetime
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
def setup_logging(log_level=logging.INFO):
    """Configures basic logging."""
    logging.basicConfig(level=log_level, format=LOG_FORMAT)

def save_report(data, tool_name, output_dir="results", filename_prefix="report"):
    """Saves the collected data to a JSON file."""
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"{filename_prefix}_{tool_name}_{timestamp}.json"
    filepath = os.path.join(output_dir, filename)
    try:
        with open(filepath, 'w') as f:
            json.dump(data, f, indent=4)
        logging.info(f"Successfully saved {tool_name} report to {filepath}")
        return filepath
    except Exception as e:
        logging.error(f"Failed to save {tool_name} report to {filepath}: {e}")
        return None

def parse_json_lines_file(filepath):
    """Parses a file containing JSON objects, one per line."""
    results = []
    if not os.path.exists(filepath):
        logging.error(f"File not found for parsing: {filepath}")
        return results
    try:
        with open(filepath, 'r') as f:
            for line in f:
                try:
                    if line.strip():
                        results.append(json.loads(line))
                except json.JSONDecodeError as e:
                    logging.warning(f"Skipping invalid JSON line in {filepath}: {line.strip()} - Error: {e}")
        return results
    except Exception as e:
        logging.error(f"Failed to read or parse JSON lines file {filepath}: {e}")
        return []  # Return empty list on failure

def parse_json_file(filepath):
    """Parses a standard JSON file."""
    if not os.path.exists(filepath):
        logging.error(f"File not found for parsing: {filepath}")
        return None
    try:
        with open(filepath, 'r') as f:
            data = json.load(f)
        return data
    except json.JSONDecodeError as e:
        logging.error(f"Invalid JSON format in {filepath}: {e}")
        return None
    except Exception as e:
        logging.error(f"Failed to read or parse JSON file {filepath}: {e}")
        return None
```
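A short, hedged usage sketch of these helpers (the payload and tool name are made up):
```python
from src.security.utils import setup_logging, save_report, parse_json_file

setup_logging()
path = save_report({"findings": []}, tool_name="semgrep")  # -> results/report_semgrep_<timestamp>.json
print(parse_json_file(path))
```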
--------------------------------------------------------------------------------
/src/dom/history/view.py:
--------------------------------------------------------------------------------
```python
# /src/dom/history/view.py
from dataclasses import dataclass
from typing import Optional, List, Dict, Union # Added Dict
# Use Pydantic for coordinate models if available and desired
try:
    from pydantic import BaseModel, Field

    class Coordinates(BaseModel):
        x: float  # Use float for potential subpixel values
        y: float

    class CoordinateSet(BaseModel):
        # Match names from buildDomTree.js if they differ
        top_left: Coordinates
        top_right: Coordinates
        bottom_left: Coordinates
        bottom_right: Coordinates
        center: Coordinates
        width: float
        height: float

    class ViewportInfo(BaseModel):
        scroll_x: float = Field(alias="scrollX")  # Match JS key if needed
        scroll_y: float = Field(alias="scrollY")
        width: float
        height: float

except ImportError:
    # Fallback if Pydantic is not installed (less type safety)
    Coordinates = Dict[str, float]
    CoordinateSet = Dict[str, Union[Coordinates, float]]
    ViewportInfo = Dict[str, float]
    BaseModel = object  # Placeholder

@dataclass
class HashedDomElement:
    """ Hash components of a DOM element for comparison. """
    branch_path_hash: str
    attributes_hash: str
    xpath_hash: str
    # text_hash: str (Still excluded)

@dataclass
class DOMHistoryElement:
    """ A serializable representation of a DOM element's state at a point in time. """
    tag_name: str
    xpath: str
    highlight_index: Optional[int]
    entire_parent_branch_path: List[str]
    attributes: Dict[str, str]
    shadow_root: bool = False
    css_selector: Optional[str] = None  # Generated enhanced selector
    # Store the Pydantic models or dicts directly
    page_coordinates: Optional[CoordinateSet] = None
    viewport_coordinates: Optional[CoordinateSet] = None
    viewport_info: Optional[ViewportInfo] = None

    def to_dict(self) -> dict:
        """ Converts the history element to a dictionary. """
        data = {
            'tag_name': self.tag_name,
            'xpath': self.xpath,
            'highlight_index': self.highlight_index,
            'entire_parent_branch_path': self.entire_parent_branch_path,
            'attributes': self.attributes,
            'shadow_root': self.shadow_root,
            'css_selector': self.css_selector,
            # Handle Pydantic models correctly if used
            'page_coordinates': self.page_coordinates.model_dump() if isinstance(self.page_coordinates, BaseModel) else self.page_coordinates,
            'viewport_coordinates': self.viewport_coordinates.model_dump() if isinstance(self.viewport_coordinates, BaseModel) else self.viewport_coordinates,
            'viewport_info': self.viewport_info.model_dump() if isinstance(self.viewport_info, BaseModel) else self.viewport_info,
        }
        # Filter out None values if desired
        # return {k: v for k, v in data.items() if v is not None}
        return data
```
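A hedged sketch of building and serializing one of these snapshots by hand (all values are invented for illustration):
```python
from src.dom.history.view import DOMHistoryElement

snapshot = DOMHistoryElement(
    tag_name="button",
    xpath="/html/body/div[1]/form/button",
    highlight_index=7,
    entire_parent_branch_path=["html", "body", "div", "form", "button"],
    attributes={"id": "submit", "class": "btn"},
    css_selector="button#submit",
)
print(snapshot.to_dict()["css_selector"])  # -> "button#submit"
```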
--------------------------------------------------------------------------------
/src/security/nuclei_scanner.py:
--------------------------------------------------------------------------------
```python
# nuclei_scanner.py
import logging
import subprocess
import os
import shlex
from datetime import datetime
from .utils import parse_json_file # Relative import
NUCLEI_TIMEOUT_SECONDS = 900 # 15 minutes default
def run_nuclei(target_url: str, output_dir="results", timeout=NUCLEI_TIMEOUT_SECONDS, severity="low,medium,high,critical"):
    """Runs the Nuclei security scanner against a target URL or IP."""
    if not target_url:
        logging.error("Nuclei target URL/IP is required")
        return []
    logging.info(f"Starting Nuclei scan for target: {target_url}")
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_filename = f"nuclei_output_{timestamp}.json"
    output_filepath = os.path.join(output_dir, output_filename)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # Configure nuclei command with common best practices
    command = [
        "nuclei",
        "-target", target_url,
        "-json",
        "-o", output_filepath,
        "-severity", severity,
        "-silent"
    ]
    logging.debug(f"Executing Nuclei command: {' '.join(shlex.quote(cmd) for cmd in command)}")
    try:
        result = subprocess.run(command, capture_output=True, text=True, timeout=timeout, check=False)
        logging.info("Nuclei process finished.")
        logging.debug(f"Nuclei stdout:\n{result.stdout}")
        if result.returncode != 0:
            logging.warning(f"Nuclei exited with non-zero status code: {result.returncode}")
            return [f"Nuclei exited with non-zero status code: {result.returncode}"]
        # Parse the JSON output file
        findings = parse_json_file(output_filepath)
        if findings:
            logging.info(f"Successfully parsed {len(findings)} findings from Nuclei output.")
            # Add tool name for context
            for finding in findings:
                finding['tool'] = 'Nuclei'
                # Standardize some fields to match our expected format
                if 'info' in finding:
                    finding['severity'] = finding.get('info', {}).get('severity')
                    finding['message'] = finding.get('info', {}).get('name')
                    finding['description'] = finding.get('info', {}).get('description')
                finding['matched_at'] = finding.get('matched-at', '')
            return findings
        else:
            logging.warning(f"Could not parse findings from Nuclei output file: {output_filepath}")
            return [f"Could not parse findings from Nuclei output file: {output_filepath}"]
    except subprocess.TimeoutExpired:
        logging.error(f"Nuclei scan timed out after {timeout} seconds.")
        return [f"Nuclei scan timed out after {timeout} seconds."]
    except FileNotFoundError:
        logging.error("Nuclei command not found. Is Nuclei installed and in PATH?")
        return ["Nuclei command not found. Is Nuclei installed and in PATH?"]
    except Exception as e:
        logging.error(f"An unexpected error occurred while running Nuclei: {e}")
        return [f"An unexpected error occurred while running Nuclei: {e}"]
```
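A hedged usage sketch (assumes the `nuclei` binary is installed and on PATH; the target is only an example):
```python
from src.security.nuclei_scanner import run_nuclei

findings = run_nuclei("https://example.com", severity="high,critical")
for f in findings:
    # Strings in the returned list indicate errors; dicts are parsed findings.
    print(f if isinstance(f, str) else f.get("message"))
```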
--------------------------------------------------------------------------------
/src/security/semgrep_scanner.py:
--------------------------------------------------------------------------------
```python
# semgrep_scanner.py
import logging
import subprocess
import os
import shlex
from datetime import datetime
from .utils import parse_json_file # Relative import
SEMGREP_TIMEOUT_SECONDS = 600 # 10 minutes default
def run_semgrep(code_path: str, config: str = "auto", output_dir="results", timeout=SEMGREP_TIMEOUT_SECONDS):
    """Runs the Semgrep CLI tool."""
    if not os.path.isdir(code_path):
        logging.error(f"Semgrep target path is not a valid directory: {code_path}")
        return []
    logging.info(f"Starting Semgrep scan for codebase: {code_path} using config: {config}")
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_filename = f"semgrep_output_{timestamp}.json"
    output_filepath = os.path.join(output_dir, output_filename)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # Use --json for machine-readable output
    command = ["semgrep", "scan", "--config", config, "--json", "-o", output_filepath, code_path]
    logging.debug(f"Executing Semgrep command: {' '.join(shlex.quote(cmd) for cmd in command)}")
    try:
        result = subprocess.run(command, capture_output=True, text=True, timeout=timeout, check=False)  # check=False
        logging.info("Semgrep process finished.")
        logging.debug(f"Semgrep stdout:\n{result.stdout}")  # Often has progress info
        # if result.stderr:
        #     logging.warning(f"Semgrep stderr:\n{result.stderr}")
        #     return [f"semgrep stderr: \n{result.stderr}"]
        if result.returncode != 0:
            logging.warning(f"Semgrep exited with non-zero status code: {result.returncode}")
            return [f"Semgrep exited with non-zero status code: {result.returncode}"]
        # It might still produce output even with errors (e.g., parse errors)
        # Parse the JSON output file
        report_data = parse_json_file(output_filepath)
        if report_data and "results" in report_data:
            findings = report_data["results"]
            logging.info(f"Successfully parsed {len(findings)} findings from Semgrep output.")
            # Add tool name for context
            for finding in findings:
                finding['tool'] = 'Semgrep'
                # Simplify structure slightly if needed
                finding['message'] = finding.get('extra', {}).get('message')
                finding['severity'] = finding.get('extra', {}).get('severity')
                finding['code_snippet'] = finding.get('extra', {}).get('lines')
            return findings
        else:
            logging.warning(f"Could not parse findings from Semgrep output file: {output_filepath}")
            return [f"Could not parse findings from Semgrep output file: {output_filepath}"]
    except subprocess.TimeoutExpired:
        logging.error(f"Semgrep scan timed out after {timeout} seconds.")
        return [f"Semgrep scan timed out after {timeout} seconds."]
    except FileNotFoundError:
        logging.error("Semgrep command not found. Is Semgrep installed and in PATH?")
        return ["Semgrep command not found. Is Semgrep installed and in PATH?"]
    except Exception as e:
        logging.error(f"An unexpected error occurred while running Semgrep: {e}")
        return [f"An unexpected error occurred while running Semgrep: {e}"]
```
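A similarly hedged sketch for the SAST side, scanning the current directory with Semgrep's `auto` config:
```python
from src.security.semgrep_scanner import run_semgrep

for finding in run_semgrep(".", config="auto"):
    if isinstance(finding, dict):  # strings indicate errors
        print(finding.get("severity"), finding.get("message"))
```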
--------------------------------------------------------------------------------
/src/llm/llm_client.py:
--------------------------------------------------------------------------------
```python
# /src/llm/llm_client.py
from google import genai
from PIL import Image
import io
import logging
import time # Import time module
import threading # Import threading for lock
from typing import Type, Optional, Union, List, Dict, Any
logger = logging.getLogger(__name__)
import base64
import json
from .clients.gemini_client import GeminiClient
from .clients.azure_openai_client import AzureOpenAIClient
from .clients.openai_client import OpenAIClient
class LLMClient:
    """
    Handles interactions with LLM APIs (Google Gemini or any LLM with OpenAI sdk)
    with rate limiting.
    """
    # Rate limiting parameters (adjust based on the specific API limits)
    # Consider making this provider-specific if needed
    MIN_REQUEST_INTERVAL_SECONDS = 3.0  # Adjusted slightly, Gemini free is 15 RPM (4s), LLM depends on tier

    def __init__(self, provider: str):  # 'gemini' or 'LLM'
        """
        Initializes the LLM client for the specified provider.

        Args:
            provider: The LLM provider to use ('gemini' or 'openai' or 'azure').
        """
        self.provider = provider.lower()
        self.client = None
        if self.provider == 'gemini':
            self.client = GeminiClient()
        elif self.provider == 'openai':
            self.client = OpenAIClient()
        elif self.provider == 'azure':
            self.client = AzureOpenAIClient()
        else:
            raise ValueError(f"Unsupported provider: {provider}. Choose 'gemini' or 'openai' or 'azure'.")
        # Common initialization
        self._last_request_time = 0.0
        self._lock = threading.Lock()  # Lock for rate limiting
        logger.info(f"LLMClient initialized for provider '{self.provider}' with {self.MIN_REQUEST_INTERVAL_SECONDS}s request interval.")

    def _wait_for_rate_limit(self):
        """Waits if necessary to maintain the minimum request interval."""
        with self._lock:  # Ensure thread-safe access
            now = time.monotonic()
            elapsed = now - self._last_request_time
            wait_time = self.MIN_REQUEST_INTERVAL_SECONDS - elapsed
            if wait_time > 0:
                logger.debug(f"Rate limiting: Waiting for {wait_time:.2f} seconds...")
                time.sleep(wait_time)
            self._last_request_time = time.monotonic()  # Update after potential wait

    def generate_text(self, prompt: str) -> str:
        """Generates text using the configured LLM provider, respecting rate limits."""
        self._wait_for_rate_limit()  # Wait before making the API call
        return self.client.generate_text(prompt)

    def generate_multimodal(self, prompt: str, image_bytes: bytes) -> str:
        """Generates text based on a prompt and an image, respecting rate limits."""
        self._wait_for_rate_limit()  # Wait before making the API call
        return self.client.generate_multimodal(prompt, image_bytes)

    def generate_json(self, Schema_Class: Type, prompt: str, image_bytes: Optional[bytes] = None) -> Union[Dict[str, Any], str]:
        """
        Generates structured JSON output based on a prompt, an optional image,
        and a defined schema, respecting rate limits.

        For Gemini, Schema_Class should be a Pydantic BaseModel or compatible type.
        For any other LLM, Schema_Class must be a Pydantic BaseModel.

        Returns:
            A dictionary representing the parsed JSON on success, or an error string.
        """
        self._wait_for_rate_limit()
        return self.client.generate_json(Schema_Class, prompt, image_bytes)
```
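A hedged usage sketch of this wrapper; the Pydantic schema is invented for illustration and the provider string must match one of the branches above:
```python
from pydantic import BaseModel
from src.llm.llm_client import LLMClient

class StepSuggestion(BaseModel):  # hypothetical schema for structured output
    description: str
    selector: str

client = LLMClient(provider="gemini")  # or "openai" / "azure"
print(client.generate_text("Summarize this page in one sentence: ..."))
suggestion = client.generate_json(StepSuggestion, "Suggest one test step for a login form.")
print(suggestion)  # parsed result on success, error string on failure
```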
--------------------------------------------------------------------------------
/test_schema.md:
--------------------------------------------------------------------------------
```markdown
```json
// output/test_case_example.json
{
  "test_name": "Login Functionality Test",
  "feature_description": "User logs in with valid credentials and verifies the welcome message.",
  "recorded_at": "2023-10-27T10:00:00Z",
  "steps": [
    {
      "step_id": 1,
      "action": "navigate",
      "description": "Navigate to the login page", // Natural language
      "parameters": {
        "url": "https://practicetestautomation.com/practice-test-login/"
      },
      "selector": null, // Not applicable
      "wait_after_secs": 1.0 // Optional: Simple wait after action
    },
    {
      "step_id": 2,
      "action": "type",
      "description": "Type username 'student'",
      "parameters": {
        "text": "student",
        "parameter_name": "username" // Optional: For parameterization
      },
      "selector": "#username", // Recorded robust selector
      "wait_after_secs": 0.5
    },
    {
      "step_id": 3,
      "action": "type",
      "description": "Type password 'Password123'",
      "parameters": {
        "text": "Password123",
        "parameter_name": "password" // Optional: For parameterization
      },
      "selector": "input[name='password']",
      "wait_after_secs": 0.5
    },
    {
      "step_id": 4,
      "action": "click",
      "description": "Click the submit button",
      "parameters": {},
      "selector": "button#submit",
      "wait_after_secs": 1.0 // Longer wait after potential navigation/update
    },
    {
      "step_id": 5,
      "action": "wait_for_load_state", // Explicit wait example
      "description": "Wait for page load after submit",
      "parameters": {
        "state": "domcontentloaded" // Or "load", "networkidle"
      },
      "selector": null,
      "wait_after_secs": 0
    },
    {
      "step_id": 6,
      "action": "assert_text_contains",
      "description": "Verify success message is shown",
      "parameters": {
        "expected_text": "Congratulations student. You successfully logged in!"
      },
      "selector": "div.post-content p strong", // Selector for the element containing the text
      "wait_after_secs": 0
    },
    {
      "step_id": 7,
      "action": "assert_visible",
      "description": "Verify logout button is visible",
      "parameters": {},
      "selector": "a.wp-block-button__link:has-text('Log out')",
      "wait_after_secs": 0
    },
    {
      "step_id": 8, // Example ID
      "action": "select",
      "description": "Select 'Weekly' notification frequency",
      "parameters": {
        "option_label": "Weekly" // Store the label (or value if preferred)
        // "parameter_name": "notification_pref" // Optional parameterization
      },
      "selector": "select#notificationFrequency", // Selector for the <select> element
      "wait_after_secs": 0.5
    },
    {
      "step_id": 9, // Example ID
      "action": "assert_passed_verification", // Special action
      "description": "Verify user avatar is displayed in header", // Original goal
      "parameters": {
        // Optional: might include reasoning from recorder's AI
        "reasoning": "The avatar image was visually confirmed present by the vision LLM during recording."
      },
      "selector": null, // No specific selector needed for executor's check
      "wait_after_secs": 0
      // NOTE: During execution, the TestExecutor will take a screenshot
      // and use its own vision LLM call to re-verify the condition
      // described in the 'description' field. It passes if the LLM
      // confirms visually, otherwise it fails the test.
    }
    // ... more steps
  ]
}
```
```
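For orientation, a minimal sketch of how steps in this format could be replayed with Playwright. The project's actual `TestExecutor` (in `src/execution/executor.py`) handles more actions, selector healing, and evidence capture; this only covers a few of the actions shown above.
```python
# Minimal replay sketch; assumes the Playwright sync API and the step fields shown above.
import json, time
from playwright.sync_api import sync_playwright

def replay(test_file: str):
    steps = json.load(open(test_file))["steps"]
    with sync_playwright() as p:
        page = p.chromium.launch(headless=True).new_page()
        for step in steps:
            action, params, sel = step["action"], step["parameters"], step.get("selector")
            if action == "navigate":
                page.goto(params["url"])
            elif action == "type":
                page.fill(sel, params["text"])
            elif action == "click":
                page.click(sel)
            elif action == "wait_for_load_state":
                page.wait_for_load_state(params.get("state", "load"))
            elif action == "assert_text_contains":
                assert params["expected_text"] in page.inner_text(sel)
            elif action == "assert_visible":
                assert page.is_visible(sel)
            time.sleep(step.get("wait_after_secs", 0) or 0)
```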
--------------------------------------------------------------------------------
/src/utils/image_utils.py:
--------------------------------------------------------------------------------
```python
from PIL import Image, ImageDraw, ImageFont
import io
import logging
import base64
from typing import Optional
import os
from ..llm.llm_client import LLMClient
logger = logging.getLogger(__name__)
# Helper Function
def stitch_images(img1: Image.Image, img2: Image.Image, label1="Baseline", label2="Current") -> Optional[Image.Image]:
    """Stitches two images side-by-side with labels."""
    if img1.size != img2.size:
        logger.error("Cannot stitch images of different sizes.")
        return None
    width1, height1 = img1.size
    width2, height2 = img2.size  # Should be same as height1
    # Add padding for labels
    label_height = 30  # Adjust as needed
    total_width = width1 + width2
    total_height = height1 + label_height
    stitched_img = Image.new('RGBA', (total_width, total_height), (255, 255, 255, 255))  # White background
    # Paste images
    stitched_img.paste(img1, (0, label_height))
    stitched_img.paste(img2, (width1, label_height))
    # Add labels
    try:
        draw = ImageDraw.Draw(stitched_img)
        # Attempt to load a simple font (adjust path or use default if needed)
        try:
            # On Linux/macOS, common paths
            font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
            if not os.path.exists(font_path): font_path = "/System/Library/Fonts/Supplemental/Arial Bold.ttf"  # macOS fallback
            font = ImageFont.truetype(font_path, 15)
        except IOError:
            logger.warning("Default font not found, using Pillow's default.")
            font = ImageFont.load_default()
        # Label 1 (Baseline)
        label1_pos = (10, 5)
        draw.text(label1_pos, f"1: {label1}", fill=(0, 0, 0, 255), font=font)
        # Label 2 (Current)
        label2_pos = (width1 + 10, 5)
        draw.text(label2_pos, f"2: {label2}", fill=(0, 0, 0, 255), font=font)
    except Exception as e:
        logger.warning(f"Could not add labels to stitched image: {e}")
        # Return image without labels if drawing fails
    stitched_img.save("./stitched.png")
    return stitched_img

def compare_images(prompt: str, image_bytes_1: bytes, image_bytes_2: bytes, image_client: LLMClient) -> str:
    """
    Compares two images using the multimodal LLM based on the prompt,
    by stitching them into a single image first.
    """
    logger.info("Preparing images for stitched comparison...")
    try:
        img1 = Image.open(io.BytesIO(image_bytes_1)).convert("RGBA")
        img2 = Image.open(io.BytesIO(image_bytes_2)).convert("RGBA")
        if img1.size != img2.size:
            error_msg = f"Visual Comparison Failed: Image dimensions mismatch. Baseline: {img1.size}, Current: {img2.size}."
            logger.error(error_msg)
            return f"Error: {error_msg}"  # Return error directly
        stitched_image_pil = stitch_images(img1, img2)
        if not stitched_image_pil:
            return "Error: Failed to stitch images."
        # Convert stitched image to bytes
        stitched_buffer = io.BytesIO()
        stitched_image_pil.save(stitched_buffer, format="PNG")
        stitched_image_bytes = stitched_buffer.getvalue()
        logger.info(f"Images stitched successfully (new size: {stitched_image_pil.size}). Requesting LLM comparison...")
    except Exception as e:
        logger.error(f"Error processing images for stitching: {e}", exc_info=True)
        return f"Error: Image processing failed - {e}"
    return image_client.generate_multimodal(prompt, stitched_image_bytes)
```
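A hedged usage sketch (the file paths are examples; `LLMClient` is the wrapper from `src/llm/llm_client.py` shown earlier):
```python
from pathlib import Path
from src.llm.llm_client import LLMClient
from src.utils.image_utils import compare_images

baseline = Path("visual_baselines/login.png").read_bytes()
current = Path("output/login_current.png").read_bytes()
verdict = compare_images(
    "Image 1 is the baseline, image 2 is the current page. Describe any visual differences.",
    baseline, current, LLMClient(provider="gemini"),
)
print(verdict)
```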
--------------------------------------------------------------------------------
/src/llm/clients/gemini_client.py:
--------------------------------------------------------------------------------
```python
# /src/llm/clients/gemini_client.py
from google import genai
from PIL import Image
import io
import logging
import time # Import time module
import threading # Import threading for lock
from typing import Type, Optional, Union, List, Dict, Any
logger = logging.getLogger(__name__)
import base64
import json
from ...utils.utils import load_api_key
class GeminiClient:
    def __init__(self):
        self.client = None
        gemini_api_key = load_api_key()
        if not gemini_api_key:
            raise ValueError("gemini_api_key is required for provider 'gemini'")
        try:
            # genai.configure(api_key=gemini_api_key) # configure is global, prefer Client
            self.client = genai.Client(api_key=gemini_api_key)
            # Test connection slightly by listing models (optional)
            # list(self.client.models.list())
            logger.info("Google Gemini Client initialized.")
        except Exception as e:
            logger.error(f"Failed to initialize Google Gemini Client: {e}", exc_info=True)
            raise RuntimeError(f"Gemini client initialization failed: {e}")

    def generate_text(self, prompt: str) -> str:
        """Generates text using the Gemini text model, respecting rate limits."""
        try:
            # Truncate prompt for logging if too long
            log_prompt = prompt[:200] + ('...' if len(prompt) > 200 else '')
            logger.debug(f"Sending text prompt (truncated): {log_prompt}")
            # response = self.text_model.generate_content(prompt)
            response = self.client.models.generate_content(
                model='gemini-2.0-flash',
                contents=prompt
            )
            logger.debug("Received text response.")
            # Improved response handling
            if hasattr(response, 'text'):
                return response.text
            elif response.parts:
                # Sometimes response might be in parts without direct .text attribute
                return "".join(part.text for part in response.parts if hasattr(part, 'text'))
            elif response.prompt_feedback and response.prompt_feedback.block_reason:
                block_reason = response.prompt_feedback.block_reason
                block_message = f"Error: Content generation blocked due to {block_reason}"
                if response.prompt_feedback.safety_ratings:
                    block_message += f" - Safety Ratings: {response.prompt_feedback.safety_ratings}"
                logger.warning(block_message)
                return block_message
            else:
                logger.warning(f"Text generation returned no text/parts and no block reason. Response: {response}")
                return "Error: Empty or unexpected response from LLM."
        except Exception as e:
            logger.error(f"Error during Gemini text generation: {e}", exc_info=True)
            return f"Error: Failed to communicate with Gemini API - {type(e).__name__}: {e}"

    def generate_multimodal(self, prompt: str, image_bytes: bytes) -> str:
        """Generates text based on a prompt and an image, respecting rate limits."""
        try:
            log_prompt = prompt[:200] + ('...' if len(prompt) > 200 else '')
            # logger.debug(f"Sending multimodal prompt (truncated): {log_prompt} with image.")
            image = Image.open(io.BytesIO(image_bytes))
            # response = self.vision_model.generate_content([prompt, image])
            response = self.client.models.generate_content(
                model='gemini-2.0-flash',
                contents=[
                    prompt,
                    image
                ]
            )
            logger.debug("Received multimodal response.")
            # Improved response handling (similar to text)
            if hasattr(response, 'text'):
                return response.text
            elif response.parts:
                return "".join(part.text for part in response.parts if hasattr(part, 'text'))
            elif response.prompt_feedback and response.prompt_feedback.block_reason:
                block_reason = response.prompt_feedback.block_reason
                block_message = f"Error: Multimodal generation blocked due to {block_reason}"
                if response.prompt_feedback.safety_ratings:
                    block_message += f" - Safety Ratings: {response.prompt_feedback.safety_ratings}"
                logger.warning(block_message)
                return block_message
            else:
                logger.warning(f"Multimodal generation returned no text/parts and no block reason. Response: {response}")
                return "Error: Empty or unexpected response from Vision LLM."
        except Exception as e:
            logger.error(f"Error during Gemini multimodal generation: {e}", exc_info=True)
            return f"Error: Failed to communicate with Gemini Vision API - {type(e).__name__}: {e}"

    def generate_json(self, Schema_Class: Type, prompt: str, image_bytes: Optional[bytes] = None) -> Union[Dict[str, Any], str]:
        """generates json based on prompt and a defined schema"""
        contents = prompt
        if image_bytes is not None:
            image = Image.open(io.BytesIO(image_bytes))
            contents = [prompt, image]
        try:
            log_prompt = prompt[:200] + ('...' if len(prompt) > 200 else '')
            logger.debug(f"Sending text prompt (truncated): {log_prompt}")
            response = self.client.models.generate_content(
                model='gemini-2.0-flash',
                contents=contents,
                config={
                    'response_mime_type': 'application/json',
                    'response_schema': Schema_Class
                }
            )
            logger.debug("Received json response from LLM")
            if hasattr(response, 'parsed'):
                return response.parsed
            elif response.prompt_feedback and response.prompt_feedback.block_reason:
                block_reason = response.prompt_feedback.block_reason
                block_message = f"Error: JSON generation blocked due to {block_reason}"
                if response.prompt_feedback.safety_ratings:
                    block_message += f" - Safety Ratings: {response.prompt_feedback.safety_ratings}"
                logger.warning(block_message)
                return block_message
            else:
                logger.warning(f"JSON generation returned no text/parts and no block reason. Response: {response}")
                return "Error: Empty or unexpected response from JSON LLM."
        except Exception as e:
            logger.error(f"Error during Gemini JSON generation: {e}", exc_info=True)
            return f"Error: Failed to communicate with Gemini JSON API - {type(e).__name__}: {e}"
```
--------------------------------------------------------------------------------
/src/dom/history/service.py:
--------------------------------------------------------------------------------
```python
# /src/dom/history/service.py
import hashlib
from typing import Optional, List, Dict # Added Dict
# Use relative imports
from ..views import DOMElementNode # Import from parent package's views
from .view import DOMHistoryElement, HashedDomElement # Import from sibling view
# Requires BrowserContext._enhanced_css_selector_for_element
# This needs to be available. Let's assume DomService provides it statically for now.
from ..service import DomService
class HistoryTreeProcessor:
"""
Operations for comparing DOM elements across different states using hashing.
"""
@staticmethod
def convert_dom_element_to_history_element(dom_element: DOMElementNode) -> DOMHistoryElement:
"""Converts a live DOMElementNode to a serializable DOMHistoryElement."""
if not dom_element: return None # Added safety check
parent_branch_path = HistoryTreeProcessor._get_parent_branch_path(dom_element)
# Use the static method from DomService to generate the selector
css_selector = DomService._enhanced_css_selector_for_element(dom_element)
# Ensure coordinate/viewport data is copied correctly
page_coords = dom_element.page_coordinates.model_dump() if dom_element.page_coordinates else None
viewport_coords = dom_element.viewport_coordinates.model_dump() if dom_element.viewport_coordinates else None
viewport_info = dom_element.viewport_info.model_dump() if dom_element.viewport_info else None
return DOMHistoryElement(
tag_name=dom_element.tag_name,
xpath=dom_element.xpath,
highlight_index=dom_element.highlight_index,
entire_parent_branch_path=parent_branch_path,
attributes=dom_element.attributes,
shadow_root=dom_element.shadow_root,
css_selector=css_selector, # Use generated selector
# Pass the Pydantic models directly if DOMHistoryElement expects them
page_coordinates=dom_element.page_coordinates,
viewport_coordinates=dom_element.viewport_coordinates,
viewport_info=dom_element.viewport_info,
)
@staticmethod
def find_history_element_in_tree(dom_history_element: DOMHistoryElement, tree: DOMElementNode) -> Optional[DOMElementNode]:
"""Finds an element in a new DOM tree that matches a historical element."""
if not dom_history_element or not tree: return None
hashed_dom_history_element = HistoryTreeProcessor._hash_dom_history_element(dom_history_element)
# Define recursive search function
def process_node(node: DOMElementNode) -> Optional[DOMElementNode]:
if not isinstance(node, DOMElementNode): # Skip non-element nodes
return None
# Only hash and compare elements that could potentially match (e.g., have attributes/xpath)
# Optimization: maybe check tag_name first?
if node.tag_name == dom_history_element.tag_name:
hashed_node = HistoryTreeProcessor._hash_dom_element(node)
if hashed_node == hashed_dom_history_element:
# Found a match based on hash
# Optional: Add secondary checks here if needed (e.g., text snippet)
return node
# Recursively search children
for child in node.children:
# Important: Only recurse into DOMElementNode children
if isinstance(child, DOMElementNode):
result = process_node(child)
if result is not None:
return result # Return immediately if found in subtree
return None # Not found in this branch
return process_node(tree)
@staticmethod
def compare_history_element_and_dom_element(dom_history_element: DOMHistoryElement, dom_element: DOMElementNode) -> bool:
"""Compares a historical element and a live element using hashes."""
if not dom_history_element or not dom_element: return False
hashed_dom_history_element = HistoryTreeProcessor._hash_dom_history_element(dom_history_element)
hashed_dom_element = HistoryTreeProcessor._hash_dom_element(dom_element)
return hashed_dom_history_element == hashed_dom_element
@staticmethod
def _hash_dom_history_element(dom_history_element: DOMHistoryElement) -> Optional[HashedDomElement]:
"""Generates a hash object from a DOMHistoryElement."""
if not dom_history_element: return None
# Use the stored parent path
branch_path_hash = HistoryTreeProcessor._parent_branch_path_hash(dom_history_element.entire_parent_branch_path)
attributes_hash = HistoryTreeProcessor._attributes_hash(dom_history_element.attributes)
xpath_hash = HistoryTreeProcessor._xpath_hash(dom_history_element.xpath)
return HashedDomElement(branch_path_hash, attributes_hash, xpath_hash)
@staticmethod
def _hash_dom_element(dom_element: DOMElementNode) -> Optional[HashedDomElement]:
"""Generates a hash object from a live DOMElementNode."""
if not dom_element: return None
parent_branch_path = HistoryTreeProcessor._get_parent_branch_path(dom_element)
branch_path_hash = HistoryTreeProcessor._parent_branch_path_hash(parent_branch_path)
attributes_hash = HistoryTreeProcessor._attributes_hash(dom_element.attributes)
xpath_hash = HistoryTreeProcessor._xpath_hash(dom_element.xpath)
# text_hash = DomTreeProcessor._text_hash(dom_element) # Text hash still excluded
return HashedDomElement(branch_path_hash, attributes_hash, xpath_hash)
@staticmethod
def _get_parent_branch_path(dom_element: DOMElementNode) -> List[str]:
"""Gets the list of tag names from the element up to the root."""
parents: List[str] = [] # Store tag names directly
current_element: Optional[DOMElementNode] = dom_element
while current_element is not None:
# Prepend tag name to maintain order from root to element
parents.insert(0, current_element.tag_name)
current_element = current_element.parent # Access the parent attribute
# Note: the path includes the element itself. The XPath produced by the JS
# buildDomTree script also includes the element, so keep it for consistency
# (i.e. do not pop the first entry).
return parents
@staticmethod
def _parent_branch_path_hash(parent_branch_path: List[str]) -> str:
"""Hashes the parent branch path string."""
# Normalize: use lowercase tags and join consistently
parent_branch_path_string = '/'.join(tag.lower() for tag in parent_branch_path)
return hashlib.sha256(parent_branch_path_string.encode('utf-8')).hexdigest()
@staticmethod
def _attributes_hash(attributes: Dict[str, str]) -> str:
"""Hashes the element's attributes dictionary."""
# Ensure consistent order by sorting keys
# Normalize attribute values (e.g., strip whitespace?) - Keep simple for now
attributes_string = ''.join(f'{key}={attributes[key]}' for key in sorted(attributes.keys()))
return hashlib.sha256(attributes_string.encode('utf-8')).hexdigest()
@staticmethod
def _xpath_hash(xpath: str) -> str:
"""Hashes the element's XPath."""
# Normalize XPath? (e.g., lowercase tags) - Assume input is consistent for now
return hashlib.sha256(xpath.encode('utf-8')).hexdigest()
# _text_hash remains commented out / unused based on the original code's decision
# @staticmethod
# def _text_hash(dom_element: DOMElementNode) -> str:
# """ """
# text_string = dom_element.get_all_text_till_next_clickable_element()
# return hashlib.sha256(text_string.encode()).hexdigest()
```
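The processor above matches elements across page loads purely by hashing the parent tag path, attributes, and XPath. A minimal usage sketch (not part of the repository) of re-locating a previously seen element in a freshly parsed tree; the import paths are inferred from the directory layout:
```python
# Sketch only: re-locate a previously seen element in a new DOM snapshot.
# Module paths are assumptions based on the repository layout.
from typing import Optional

from src.dom.history.service import HistoryTreeProcessor
from src.dom.views import DOMElementNode


def relocate(old_node: DOMElementNode, new_tree: DOMElementNode) -> Optional[DOMElementNode]:
    """Freeze a live element into its serializable history form, then search the
    new tree for a node whose (branch path, attributes, xpath) hash matches."""
    history_el = HistoryTreeProcessor.convert_dom_element_to_history_element(old_node)
    if history_el is None:
        return None
    return HistoryTreeProcessor.find_history_element_in_tree(history_el, new_tree)
```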
--------------------------------------------------------------------------------
/src/core/task_manager.py:
--------------------------------------------------------------------------------
```python
# /src/core/task_manager.py
import logging
from typing import List, Dict, Any, Optional
logger = logging.getLogger(__name__)
class TaskManager:
"""Manages the main task, subtasks, progress, and status."""
def __init__(self, max_retries_per_subtask: int = 2): # Renamed parameter for clarity internally
self.main_task: str = "" # Stores the overall feature description
self.subtasks: List[Dict[str, Any]] = [] # Stores the individual test steps
self.current_subtask_index: int = 0 # Index of the step being processed or next to process
self.max_retries_per_subtask: int = max_retries_per_subtask
logger.info(f"TaskManager (Test Mode) initialized (max_retries_per_step={max_retries_per_subtask}).")
def set_main_task(self, feature_description: str):
"""Sets the main feature description being tested."""
self.main_task = feature_description
self.subtasks = []
self.current_subtask_index = 0
logger.info(f"Feature under test set: {feature_description}")
def add_subtasks(self, test_step_list: List[str]):
"""Adds a list of test steps derived from the feature description."""
if not self.main_task:
logger.error("Cannot add test steps before setting a feature description.")
return
if not isinstance(test_step_list, list) or not all(isinstance(s, str) and s for s in test_step_list):
logger.error(f"Invalid test step list format received: {test_step_list}")
raise ValueError("Test step list must be a non-empty list of non-empty strings.")
self.subtasks = [] # Clear existing steps before adding new ones
for desc in test_step_list:
self.subtasks.append({
"description": desc, # The test step description
"status": "pending", # pending, in_progress, done, failed
"attempts": 0,
"result": None, # Store result of the step (e.g., extracted text)
"error": None, # Store error if the step failed
"_recorded_": False,
"last_failed_selector": None # Store selector if failure was element-related
})
self.current_subtask_index = 0 if self.subtasks else -1 # Reset index
logger.info(f"Added {len(test_step_list)} test steps.")
def insert_subtasks(self, index: int, new_step_descriptions: List[str]):
"""Inserts new test steps at a specific index."""
if not isinstance(new_step_descriptions, list) or not all(isinstance(s, str) and s for s in new_step_descriptions):
logger.error(f"Invalid new step list format received for insertion: {new_step_descriptions}")
return False # Indicate failure
if not (0 <= index <= len(self.subtasks)): # Allow insertion at the end
logger.error(f"Invalid index {index} for inserting subtasks (Total steps: {len(self.subtasks)}).")
return False
new_tasks = []
for desc in new_step_descriptions:
new_tasks.append({
"description": desc,
"status": "pending", # New tasks start as pending
"attempts": 0,
"result": None,
"error": None,
"_recorded_": False, # Ensure internal flags are initialized
"last_failed_selector": None
})
# Insert the new tasks into the list
self.subtasks[index:index] = new_tasks
logger.info(f"Inserted {len(new_tasks)} new subtasks at index {index}.")
# Crucial: If the insertion happens at or before the current index,
# we might need to adjust the current index, but generally, the next call
# to get_next_subtask() should find the newly inserted pending tasks if they
# are before the previously 'current' task. Let get_next_subtask handle finding the next actionable item.
# If insertion happens *after* current processing index, it doesn't immediately affect flow.
return True # Indicate success
def get_next_subtask(self) -> Optional[Dict[str, Any]]:
"""
Gets the first test step that is 'pending' or 'failed' with retries remaining.
Iterates sequentially.
"""
for index, task in enumerate(self.subtasks):
# In recorder mode, 'failed' means AI suggestion failed, allow retry
# In executor mode (if used here), 'failed' means execution failed
is_pending = task["status"] == "pending"
is_retryable_failure = (task["status"] == "failed" and
task["attempts"] <= self.max_retries_per_subtask)
if is_pending or is_retryable_failure:
# Found the next actionable step
if is_retryable_failure:
logger.info(f"Retrying test step {index + 1} (Attempt {task['attempts'] + 1}/{self.max_retries_per_subtask + 1})")
else: # Pending
logger.info(f"Starting test step {index + 1}/{len(self.subtasks)}: {task['description']}")
# Update the main index to point to this task BEFORE changing status
self.current_subtask_index = index
task["status"] = "in_progress"
task["attempts"] += 1
# Keep error context on retry, clear result
task["result"] = None
return task
# No actionable tasks found
logger.info("No more actionable test steps found.")
self.current_subtask_index = len(self.subtasks) # Mark completion
return None
def update_subtask_status(self, index: int, status: str, result: Any = None, error: Optional[str] = None, force_update: bool = False):
"""Updates the status of a specific test step."""
if 0 <= index < len(self.subtasks):
task = self.subtasks[index]
current_status = task["status"]
# Allow update only if forced or if task is 'in_progress'
# if not force_update and task["status"] != "in_progress":
# logger.warning(f"Attempted to update status of test step {index + 1} ('{task['description'][:50]}...') "
# f"from '{task['status']}' to '{status}', but it's not 'in_progress'. Ignoring (unless force_update=True).")
# return
# Log if the status is actually changing
if current_status != status:
logger.info(f"Updating Test Step {index + 1} status from '{current_status}' to '{status}'.")
else:
logger.debug(f"Test Step {index + 1} status already '{status}'. Updating result/error.")
task["status"] = status
task["result"] = result
task["error"] = error
log_message = f"Test Step {index + 1} ('{task['description'][:50]}...') processed. Status: {status}."
if result and status == 'done': log_message += f" Result: {str(result)[:100]}..."
if error: log_message += f" Error/Note: {error}"
# Use debug for potentially repetitive updates if status doesn't change
log_level = logging.INFO if current_status != status else logging.DEBUG
logger.log(log_level, log_message)
# Log permanent failure clearly
if status == "failed" and task["attempts"] > self.max_retries_per_subtask:
logger.warning(f"Test Step {index + 1} failed permanently after {task['attempts']} attempts.")
else:
logger.error(f"Invalid index {index} for updating test step status (Total steps: {len(self.subtasks)}).")
def get_current_subtask(self) -> Optional[Dict[str, Any]]:
"""Gets the test step currently marked by current_subtask_index (likely 'in_progress')."""
if 0 <= self.current_subtask_index < len(self.subtasks):
return self.subtasks[self.current_subtask_index]
return None
def is_complete(self) -> bool:
"""Checks if all test steps have been processed (are 'done' or 'failed' permanently)."""
for task in self.subtasks:
if task['status'] == 'pending' or \
task['status'] == 'in_progress' or \
(task['status'] == 'failed' and task['attempts'] <= self.max_retries_per_subtask):
return False # Found an actionable step
return True # All steps processed
```
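For reference, a minimal sketch (not part of the repository) of the step loop an agent drives with TaskManager: register the feature, add the test steps, then repeatedly pull the next actionable step and report its outcome until nothing actionable remains.
```python
# Sketch only: the TaskManager step loop as driven by an agent.
from src.core.task_manager import TaskManager

tm = TaskManager(max_retries_per_subtask=2)
tm.set_main_task("Login form shows an error for invalid credentials")
tm.add_subtasks([
    "Go to the login page",
    "Enter an invalid email and submit",
    "Verify that an error message is displayed",
])

while not tm.is_complete():
    step = tm.get_next_subtask()  # marks the step 'in_progress' and increments attempts
    if step is None:
        break
    try:
        result = f"executed: {step['description']}"  # placeholder for real browser execution
        tm.update_subtask_status(tm.current_subtask_index, "done", result=result)
    except Exception as exc:
        tm.update_subtask_status(tm.current_subtask_index, "failed", error=str(exc))
```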
--------------------------------------------------------------------------------
/src/agents/auth_agent.py:
--------------------------------------------------------------------------------
```python
# /src/agents/auth_agent.py
import time
import os
import logging
import getpass
from typing import Optional, Dict, Any
from pydantic import BaseModel, Field
from patchright.sync_api import Error as PlaywrightError, TimeoutError as PlaywrightTimeoutError
# Import necessary components from your project structure
from ..browser.browser_controller import BrowserController
from ..llm.llm_client import LLMClient # Assuming you have this initialized
from ..dom.views import DOMState # To type hint DOM state
# Configure basic logging for this script
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Generic descriptions for LLM to find elements
USERNAME_FIELD_DESC = "the username input field"
PASSWORD_FIELD_DESC = "the password input field"
SUBMIT_BUTTON_DESC = "the login or submit button"
# Element to verify login success
LOGIN_SUCCESS_SELECTOR_DESC = "the logout button or link" # Description for verification element
# --- Output file path ---
AUTH_STATE_FILE = "auth_state.json"
# ---------------------
# --- Pydantic Schema for LLM Selector Response ---
class LLMSelectorResponse(BaseModel):
selector: Optional[str] = Field(..., description="The best CSS selector found for the described element, or null if not found/identifiable.")
reasoning: str = Field(..., description="Explanation for the chosen selector or why none was found.")
# -----------------------------------------------
# --- Helper Function to Find Selector via LLM ---
def find_element_selector_via_llm(
llm_client: LLMClient,
element_description: str,
dom_state: Optional[DOMState],
page: Any # Playwright Page object for validation
) -> Optional[str]:
"""
Uses LLM to find a selector for a described element based on DOM context.
Validates the selector before returning.
"""
if not llm_client:
logger.error("LLMClient is not available.")
return None
if not dom_state or not dom_state.element_tree:
logger.error(f"Cannot find selector for '{element_description}': DOM state is not available.")
return None
try:
dom_context_str, _ = dom_state.element_tree.generate_llm_context_string(context_purpose='verification')
current_url = page.url if page else "Unknown"
prompt = f"""
You are an AI assistant identifying CSS selectors for web automation.
Based on the following HTML context and the element description, provide the most robust CSS selector.
**Current URL:** {current_url}
**Element to Find:** "{element_description}"
**HTML Context (Visible elements, interactive `[index]`, static `(Static)`):**
```html
{dom_context_str}
\```
**Your Task:**
1. Analyze the HTML context to find the single element that best matches the description "{element_description}".
2. Provide the most stable and specific CSS selector for that element. Prioritize IDs, unique data attributes (like data-testid), or name attributes. Avoid relying solely on text or highly dynamic classes if possible.
3. If no suitable element is found, return null for the selector.
**Output Format:** Respond ONLY with a JSON object matching the following schema:
```json
{{
"selector": "YOUR_SUGGESTED_CSS_SELECTOR_OR_NULL",
"reasoning": "Explain your choice or why none was found."
}}
\```
"""
logger.debug(f"Sending prompt to LLM to find selector for: '{element_description}'")
response_obj = llm_client.generate_json(LLMSelectorResponse, prompt)
if isinstance(response_obj, LLMSelectorResponse):
selector = response_obj.selector
reasoning = response_obj.reasoning
if selector:
logger.info(f"LLM suggested selector '{selector}' for '{element_description}'. Reasoning: {reasoning}")
# --- Validate Selector ---
try:
handles = page.query_selector_all(selector)
count = len(handles)
if count == 1:
logger.info(f"✅ Validation PASSED: Selector '{selector}' uniquely found the element.")
return selector
elif count > 1:
logger.warning(f"⚠️ Validation WARNING: Selector '{selector}' matched {count} elements. Using the first one.")
return selector # Still return it, maybe it's okay
else: # count == 0
logger.error(f"❌ Validation FAILED: Selector '{selector}' did not find any elements.")
return None
except Exception as validate_err:
logger.error(f"❌ Validation ERROR for selector '{selector}': {validate_err}")
return None
# --- End Validation ---
else:
logger.error(f"LLM could not find a selector for '{element_description}'. Reasoning: {reasoning}")
return None
elif isinstance(response_obj, str): # LLM Error string
logger.error(f"LLM returned an error finding selector for '{element_description}': {response_obj}")
return None
else:
logger.error(f"Unexpected response type from LLM finding selector for '{element_description}': {type(response_obj)}")
return None
except Exception as e:
logger.error(f"Error during LLM selector identification for '{element_description}': {e}", exc_info=True)
return None
# --- End Helper Function ---
# --- Main Function ---
def record_selectors_and_save_auth_state(llm_client: LLMClient, login_url: str, auth_state_file: str = AUTH_STATE_FILE):
"""
Uses LLM to find login selectors, gets credentials securely, performs login,
and saves the authentication state.
"""
logger.info("--- Authentication State Generation (Recorder-Assisted Selectors) ---")
if not login_url:
logger.error(f"Login url not provided. Exiting...")
return False
# Get credentials securely first
try:
username = input(f"Enter username (will be visible): ")
if not username: raise ValueError("Username cannot be empty.")
password = getpass.getpass(f"Enter password for '{username}' (input will be hidden): ")
if not password: raise ValueError("Password cannot be empty.")
except (EOFError, ValueError) as e:
logger.error(f"\n❌ Input error: {e}. Aborting.")
return False
except Exception as e:
logger.error(f"\n❌ Error reading input: {e}")
return False
logger.info("Initializing BrowserController (visible browser)...")
# Must run non-headless for user interaction/visibility AND selector validation
browser_controller = BrowserController(headless=False)
final_success = False
try:
browser_controller.start()
page = browser_controller.page
if not page: raise RuntimeError("Failed to initialize browser page.")
logger.info(f"Navigating browser to login page: {login_url}")
browser_controller.goto(login_url)
logger.info("Attempting to identify login form selectors using LLM...")
# Give the page a moment to settle before getting DOM
time.sleep(1)
dom_state = browser_controller.get_structured_dom(highlight_all_clickable_elements=False, viewport_expansion=-1)
# Find Selectors using the helper function
username_selector = find_element_selector_via_llm(llm_client, USERNAME_FIELD_DESC, dom_state, page)
if not username_selector: return False # Abort if not found
password_selector = find_element_selector_via_llm(llm_client, PASSWORD_FIELD_DESC, dom_state, page)
if not password_selector: return False
submit_selector = find_element_selector_via_llm(llm_client, SUBMIT_BUTTON_DESC, dom_state, page)
if not submit_selector: return False
logger.info("Successfully identified all necessary login selectors.")
logger.info(f" Username Field: '{username_selector}'")
logger.info(f" Password Field: '{password_selector}'")
logger.info(f" Submit Button: '{submit_selector}'")
input("\n-> Press Enter to proceed with login using these selectors and your credentials...")
# --- Execute Login (using identified selectors and secure credentials) ---
logger.info(f"Typing username into: {username_selector}")
browser_controller.type(username_selector, username)
time.sleep(0.3)
logger.info(f"Typing password into: {password_selector}")
browser_controller.type(password_selector, password)
time.sleep(0.3)
logger.info(f"Clicking submit button: {submit_selector}")
browser_controller.click(submit_selector)
# --- Verify Login Success ---
logger.info("Attempting to identify login success element selector using LLM...")
# Re-fetch DOM state after potential page change/update
time.sleep(1) # Wait briefly for page update
post_login_dom_state = browser_controller.get_structured_dom(highlight_all_clickable_elements=False, viewport_expansion=-1)
login_success_selector = find_element_selector_via_llm(llm_client, LOGIN_SUCCESS_SELECTOR_DESC, post_login_dom_state, page)
if not login_success_selector:
logger.error("❌ Login Verification Failed: Could not identify the confirmation element via LLM.")
raise RuntimeError("Failed to identify login confirmation element.") # Treat as failure
logger.info(f"Waiting for login confirmation element ({login_success_selector}) to appear...")
try:
page.locator(login_success_selector).wait_for(state="visible", timeout=15000)
logger.info("✅ Login successful! Confirmation element found.")
except PlaywrightTimeoutError:
logger.error(f"❌ Login Failed: Confirmation element '{login_success_selector}' did not appear within timeout.")
raise # Re-raise to be caught by the main handler
# --- Save the storage state ---
if browser_controller.context:
logger.info(f"Saving authentication state to {auth_state_file}...")
browser_controller.context.storage_state(path=auth_state_file)
logger.info(f"✅ Successfully saved authentication state.")
final_success = True
else:
logger.error("❌ Cannot save state: Browser context is not available.")
except (PlaywrightError, ValueError, RuntimeError) as e:
logger.error(f"❌ An error occurred: {type(e).__name__}: {e}", exc_info=False)
if browser_controller and browser_controller.page:
ts = time.strftime("%Y%m%d_%H%M%S")
fail_path = f"output/record_auth_error_{ts}.png"
browser_controller.save_screenshot(fail_path)
logger.info(f"Saved error screenshot to: {fail_path}")
except Exception as e:
logger.critical(f"❌ An unexpected critical error occurred: {e}", exc_info=True)
finally:
logger.info("Closing browser...")
if browser_controller:
browser_controller.close()
return final_success
# --- End Main Function ---
```
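A minimal usage sketch (not part of the repository): this flow is run manually because it opens a visible browser and prompts for credentials on the console. LLMClient(provider=...) mirrors its construction in mcp_server.py; the login URL is illustrative.
```python
# Sketch only: generate auth_state.json for a site with a standard login form.
from src.llm.llm_client import LLMClient
from src.agents.auth_agent import record_selectors_and_save_auth_state

llm_client = LLMClient(provider="azure")
ok = record_selectors_and_save_auth_state(
    llm_client,
    login_url="https://example.com/login",   # illustrative URL
    auth_state_file="auth_state.json",
)
print("Auth state saved" if ok else "Auth state generation failed")
```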
--------------------------------------------------------------------------------
/src/security/zap_scanner.py:
--------------------------------------------------------------------------------
```python
# zap_scanner.py
import logging
import subprocess
import os
import shlex
import time
import json
import requests
from datetime import datetime
from .utils import parse_json_file # Relative import
ZAP_TIMEOUT_SECONDS = 1800 # 30 minutes default
ZAP_API_PORT = 8080 # Default ZAP API port
def run_zap_scan(target_url: str, output_dir="results", timeout=ZAP_TIMEOUT_SECONDS,
zap_path=None, api_key=None, scan_mode="baseline"):
"""
Runs OWASP ZAP security scanner against a target URL.
Args:
target_url: The URL to scan
output_dir: Directory to store scan results
timeout: Maximum time in seconds for the scan
zap_path: Path to ZAP installation (uses docker by default)
api_key: ZAP API key if required
scan_mode: Type of scan - 'baseline', 'full' or 'api'
"""
if not target_url:
logging.error("ZAP target URL is required")
return []
logging.info(f"Starting ZAP scan for target: {target_url}")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_filename = f"zap_output_{timestamp}.json"
output_filepath = os.path.join(output_dir, output_filename)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# Determine if using Docker or local ZAP installation
use_docker = zap_path is None
if use_docker:
# Docker command to run ZAP in a container
command = [
"docker", "run", "--rm", "-v", f"{os.path.abspath(output_dir)}:/zap/wrk:rw",
"-t", "owasp/zap2docker-stable", "zap-" + scan_mode + "-scan.py",
"-t", target_url,
"-J", output_filename
]
if api_key:
command.extend(["-z", f"api.key={api_key}"])
else:
# Local ZAP installation
script_name = f"zap-{scan_mode}-scan.py"
command = [
os.path.join(zap_path, script_name),
"-t", target_url,
"-J", output_filepath
]
if api_key:
command.extend(["-z", f"api.key={api_key}"])
logging.debug(f"Executing ZAP command: {' '.join(shlex.quote(cmd) for cmd in command)}")
try:
result = subprocess.run(command, capture_output=True, text=True, timeout=timeout, check=False)
logging.info("ZAP process finished.")
logging.debug(f"ZAP stdout:\n{result.stdout}")
if result.returncode != 0:
logging.warning(f"ZAP exited with non-zero status code: {result.returncode}")
return [f"ZAP exited with non-zero status code: {result.returncode}"]
# For Docker, the output will be in the mapped volume
actual_output_path = output_filepath if not use_docker else os.path.join(output_dir, output_filename)
# Parse the JSON output file
report_data = parse_json_file(actual_output_path)
if report_data and "site" in report_data:
# Process ZAP findings from the report
findings = []
# Structure varies based on scan mode but generally has sites with alerts
for site in report_data.get("site", []):
site_url = site.get("@name", "")
for alert in site.get("alerts", []):
finding = {
'tool': 'OWASP ZAP',
'severity': alert.get("riskdesc", "").split(" ", 1)[0],
'message': alert.get("name", ""),
'description': alert.get("desc", ""),
'url': site_url,
'solution': alert.get("solution", ""),
'references': alert.get("reference", ""),
'cweid': alert.get("cweid", ""),
'instances': len(alert.get("instances", [])),
}
findings.append(finding)
logging.info(f"Successfully parsed {len(findings)} findings from ZAP output.")
return findings
else:
logging.warning(f"Could not parse findings from ZAP output file: {actual_output_path}")
return [f"Could not parse findings from ZAP output file: {actual_output_path}"]
except subprocess.TimeoutExpired:
logging.error(f"ZAP scan timed out after {timeout} seconds.")
return [f"ZAP scan timed out after {timeout} seconds."]
except FileNotFoundError as e:
if use_docker:
logging.error("Docker command not found. Is Docker installed and in PATH?")
return ["Docker command not found. Is Docker installed and in PATH?"]
else:
logging.error(f"ZAP command not found at {zap_path}. Is ZAP installed?")
return [f"ZAP command not found at {zap_path}. Is ZAP installed?"]
except Exception as e:
logging.error(f"An unexpected error occurred while running ZAP: {e}")
return [f"An unexpected error occurred while running ZAP: {e}"]
def run_zap_api_scan(target_url: str, api_definition: str, output_dir="results",
timeout=ZAP_TIMEOUT_SECONDS, zap_path=None, api_key=None):
"""
Runs ZAP API scan against a REST API with OpenAPI/Swagger definition.
Args:
target_url: Base URL of the API
api_definition: Path to OpenAPI/Swagger definition file
output_dir: Directory to store scan results
timeout: Maximum time in seconds for the scan
zap_path: Path to ZAP installation (uses docker by default)
api_key: ZAP API key if required
"""
if not os.path.isfile(api_definition):
logging.error(f"API definition file not found: {api_definition}")
return []
# Similar implementation as run_zap_scan but with API scanning options
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_filename = f"zap_api_output_{timestamp}.json"
output_filepath = os.path.join(output_dir, output_filename)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
use_docker = zap_path is None
if use_docker:
# Volume mount for API definition file
api_def_dir = os.path.dirname(os.path.abspath(api_definition))
api_def_file = os.path.basename(api_definition)
command = [
"docker", "run", "--rm",
"-v", f"{os.path.abspath(output_dir)}:/zap/wrk:rw",
"-v", f"{api_def_dir}:/zap/api:ro",
"-t", "owasp/zap2docker-stable", "zap-api-scan.py",
"-t", target_url,
"-f", f"/zap/api/{api_def_file}",
"-J", output_filename
]
else:
command = [
os.path.join(zap_path, "zap-api-scan.py"),
"-t", target_url,
"-f", api_definition,
"-J", output_filepath
]
if api_key:
command.extend(["-z", f"api.key={api_key}"])
logging.debug(f"Executing ZAP API scan command: {' '.join(shlex.quote(cmd) for cmd in command)}")
# The rest of the implementation follows similar pattern to run_zap_scan
try:
result = subprocess.run(command, capture_output=True, text=True, timeout=timeout, check=False)
# Processing similar to run_zap_scan
logging.info("ZAP API scan process finished.")
logging.debug(f"ZAP stdout:\n{result.stdout}")
if result.returncode != 0:
logging.warning(f"ZAP API scan exited with non-zero status code: {result.returncode}")
return [f"ZAP API scan exited with non-zero status code: {result.returncode}"]
# For Docker, the output will be in the mapped volume
actual_output_path = output_filepath if not use_docker else os.path.join(output_dir, output_filename)
# Parse the JSON output file - same processing as run_zap_scan
report_data = parse_json_file(actual_output_path)
if report_data and "site" in report_data:
findings = []
for site in report_data.get("site", []):
site_url = site.get("@name", "")
for alert in site.get("alerts", []):
finding = {
'tool': 'OWASP ZAP API Scan',
'severity': alert.get("riskdesc", "").split(" ", 1)[0],
'message': alert.get("name", ""),
'description': alert.get("desc", ""),
'url': site_url,
'solution': alert.get("solution", ""),
'references': alert.get("reference", ""),
'cweid': alert.get("cweid", ""),
'instances': len(alert.get("instances", [])),
}
findings.append(finding)
logging.info(f"Successfully parsed {len(findings)} findings from ZAP API scan output.")
return findings
else:
logging.warning(f"Could not parse findings from ZAP API scan output file: {actual_output_path}")
return [f"Could not parse findings from ZAP API scan output file: {actual_output_path}"]
except Exception as e:
logging.error(f"An unexpected error occurred while running ZAP API scan: {e}")
return [f"An unexpected error occurred while running ZAP API scan: {e}"]
def discover_endpoints(target_url: str, output_dir="results", timeout=600, zap_path=None, api_key=None):
"""
Uses ZAP's spider to discover endpoints in a web application.
Args:
target_url: The URL to scan
output_dir: Directory to store results
timeout: Maximum time in seconds for the spider
zap_path: Path to ZAP installation (uses docker by default)
api_key: ZAP API key if required
Returns:
List of discovered endpoints and their details
"""
if not target_url:
logging.error("Target URL is required for endpoint discovery")
return []
logging.info(f"Starting ZAP endpoint discovery for: {target_url}")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_filename = f"zap_endpoints_{timestamp}.json"
output_filepath = os.path.join(output_dir, output_filename)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
use_docker = zap_path is None
if use_docker:
command = [
"docker", "run", "--rm",
"-v", f"{os.path.abspath(output_dir)}:/zap/wrk:rw",
"-t", "owasp/zap2docker-stable",
"zap-full-scan.py",
"-t", target_url,
"-J", output_filename,
"-z", "-config spider.maxDuration=1", # Limit spider duration
"--spider-first", # Run spider before the scan
"-n", "endpoints.context" # Don't perform actual scan, just spider
]
else:
command = [
os.path.join(zap_path, "zap-full-scan.py"),
"-t", target_url,
"-J", output_filepath,
"-z", "-config spider.maxDuration=1",
"--spider-first",
"-n", "endpoints.context"
]
if api_key:
command.extend(["-z", f"api.key={api_key}"])
logging.debug(f"Executing ZAP endpoint discovery: {' '.join(shlex.quote(cmd) for cmd in command)}")
try:
result = subprocess.run(command, capture_output=True, text=True, timeout=timeout, check=False)
logging.info("ZAP endpoint discovery finished.")
logging.debug(f"ZAP stdout:\n{result.stdout}")
actual_output_path = output_filepath if not use_docker else os.path.join(output_dir, output_filename)
# Parse the JSON output file
report_data = parse_json_file(actual_output_path)
if report_data:
endpoints = []
# Extract endpoints from spider results
if "site" in report_data:
for site in report_data.get("site", []):
site_url = site.get("@name", "")
# Extract URLs from alerts and spider results
urls = set()
# Get URLs from alerts
for alert in site.get("alerts", []):
for instance in alert.get("instances", []):
url = instance.get("uri", "")
if url:
urls.add(url)
# Add discovered endpoints
for url in urls:
endpoint = {
'url': url,
'method': 'GET', # Default to GET, ZAP spider mainly discovers GET endpoints
'source': 'ZAP Spider',
'parameters': [], # Could be enhanced to parse URL parameters
'discovered_at': datetime.now().isoformat()
}
endpoints.append(endpoint)
logging.info(f"Successfully discovered {len(endpoints)} endpoints.")
# Save endpoints to a separate file
endpoints_file = os.path.join(output_dir, f"discovered_endpoints_{timestamp}.json")
with open(endpoints_file, 'w') as f:
json.dump(endpoints, f, indent=2)
logging.info(f"Saved discovered endpoints to: {endpoints_file}")
return endpoints
else:
logging.warning("No endpoints discovered or parsing failed.")
return []
except subprocess.TimeoutExpired:
logging.error(f"Endpoint discovery timed out after {timeout} seconds.")
return []
except Exception as e:
logging.error(f"An error occurred during endpoint discovery: {e}")
return []
```
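A minimal sketch of invoking the scanner functions above (not part of the repository). With zap_path=None the Docker path is used, so Docker must be installed and running; the target URL is illustrative. Note that on error these functions return plain strings rather than finding dictionaries, so callers should check the element type.
```python
# Sketch only: baseline scan plus spider-based endpoint discovery via Docker.
from src.security.zap_scanner import run_zap_scan, discover_endpoints

target = "https://staging.example.com"  # illustrative target

findings = run_zap_scan(target, output_dir="results", scan_mode="baseline")
for item in findings:
    if isinstance(item, dict):
        print(item["severity"], "-", item["message"])
    else:
        print("Scanner note:", item)  # error/status strings are returned as plain str

endpoints = discover_endpoints(target, output_dir="results", timeout=600)
print(f"Discovered {len(endpoints)} endpoints")
```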
--------------------------------------------------------------------------------
/src/llm/clients/openai_client.py:
--------------------------------------------------------------------------------
```python
# /src/llm/clients/openai_client.py
from PIL import Image
import io
import logging
import time # Import time module
import threading # Import threading for lock
from typing import Type, Optional, Union, List, Dict, Any
logger = logging.getLogger(__name__)
import base64
import json
from ...utils.utils import load_api_key, load_api_base_url, load_api_version, load_llm_model
# --- Provider Specific Imports ---
try:
import openai
from openai import OpenAI
from pydantic import BaseModel # Needed for LLM JSON tool definition
OPENAI_SDK = True
except ImportError:
OPENAI_SDK = False
# Define dummy classes if LLM libs are not installed to avoid NameErrors
class BaseModel: pass
class OpenAI: pass
# --- Helper Function ---
def _image_bytes_to_base64_url(image_bytes: bytes) -> Optional[str]:
"""Converts image bytes to a base64 data URL."""
try:
# Try to determine the image format
img = Image.open(io.BytesIO(image_bytes))
format = img.format
if not format:
logger.warning("Could not determine image format, assuming JPEG.")
format = "jpeg" # Default assumption
else:
format = format.lower()
if format == 'jpg': # Standardize to jpeg
format = 'jpeg'
# Ensure format is supported (common web formats)
if format not in ['jpeg', 'png', 'gif', 'webp']:
logger.warning(f"Unsupported image format '{format}' for base64 URL, defaulting to JPEG.")
format = 'jpeg' # Fallback
encoded_string = base64.b64encode(image_bytes).decode('utf-8')
return f"data:image/{format};base64,{encoded_string}"
except Exception as e:
logger.error(f"Error converting image bytes to base64 URL: {e}", exc_info=True)
return None
class OpenAIClient:
def __init__(self):
self.client = None
self.LLM_api_key = load_api_key()
self.LLM_api_version = load_api_version()
self.LLM_model_name = load_llm_model()
self.LLM_endpoint = load_api_base_url()
self.LLM_vision_model_name = self.LLM_model_name
if not OPENAI_SDK:
raise ImportError("LLM OpenAI libraries (openai, pydantic) are not installed. Please install them.")
# Note: api_version is only needed by the Azure client, so it is not required here.
if not all([self.LLM_api_key, self.LLM_endpoint, self.LLM_model_name]):
raise ValueError("LLM_api_key, LLM_endpoint, and LLM_model_name are required for provider 'LLM'")
try:
self.client = OpenAI(
api_key=self.LLM_api_key,
base_url=self.LLM_endpoint,
)
# Test connection slightly by listing models (optional, requires different permission potentially)
# self.client.models.list()
logger.info(f"LLM OpenAI Client initialized for endpoint {self.LLM_endpoint} and model {self.LLM_model_name}.")
except Exception as e:
logger.error(f"Failed to initialize LLM OpenAI Client: {e}", exc_info=True)
raise RuntimeError(f"LLM client initialization failed: {e}")
def generate_text(self, prompt: str) -> str:
try:
log_prompt = prompt[:200] + ('...' if len(prompt) > 200 else '')
logger.debug(f"[LLM] Sending text prompt (truncated): {log_prompt}")
messages = [{"role": "user", "content": prompt}]
response = self.client.chat.completions.create(
model=self.LLM_model_name,
messages=messages,
max_tokens=1024, # Adjust as needed
)
logger.debug("[LLM] Received text response.")
if response.choices:
message = response.choices[0].message
if message.content:
return message.content
else:
# Handle cases like function calls if they unexpectedly occur or content filter
finish_reason = response.choices[0].finish_reason
logger.warning(f"[LLM] Text generation returned no content. Finish reason: {finish_reason}. Response: {response.model_dump_json(indent=2)}")
if finish_reason == 'content_filter':
return "Error: [LLM] Content generation blocked due to content filter."
return "Error: [LLM] Empty response from LLM."
else:
logger.warning(f"[LLM] Text generation returned no choices. Response: {response.model_dump_json(indent=2)}")
return "Error: [LLM] No choices returned from LLM."
except openai.APIError as e:
# Handle API error here, e.g. retry or log
logger.error(f"[LLM] OpenAI API returned an API Error: {e}", exc_info=True)
return f"Error: [LLM] API Error - {type(e).__name__}: {e}"
except openai.AuthenticationError as e:
logger.error(f"[LLM] OpenAI API authentication error: {e}", exc_info=True)
return f"Error: [LLM] Authentication Error - {e}"
except openai.RateLimitError as e:
logger.error(f"[LLM] OpenAI API request exceeded rate limit: {e}", exc_info=True)
# Note: Our simple time.sleep might not be enough for LLM's complex limits
return f"Error: [LLM] Rate limit exceeded - {e}"
except Exception as e:
logger.error(f"Error during LLM text generation: {e}", exc_info=True)
return f"Error: [LLM] Failed to communicate with API - {type(e).__name__}: {e}"
def generate_multimodal(self, prompt: str, image_bytes: bytes) -> str:
if not self.LLM_vision_model_name:
logger.error("[LLM] LLM vision model name not configured.")
return "Error: [LLM] Vision model not configured."
base64_url = _image_bytes_to_base64_url(image_bytes)
if not base64_url:
return "Error: [LLM] Failed to convert image to base64."
try:
log_prompt = prompt[:200] + ('...' if len(prompt) > 200 else '')
logger.debug(f"[LLM] Sending multimodal prompt (truncated): {log_prompt} with image.")
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {"url": base64_url}},
],
}
]
response = self.client.chat.completions.create(
model=self.LLM_vision_model_name, # Use the vision model deployment
messages=messages,
max_tokens=1024, # Adjust as needed
)
logger.debug("[LLM] Received multimodal response.")
# Parsing logic similar to text generation
if response.choices:
message = response.choices[0].message
if message.content:
return message.content
else:
finish_reason = response.choices[0].finish_reason
logger.warning(f"[LLM] Multimodal generation returned no content. Finish reason: {finish_reason}. Response: {response.model_dump_json(indent=2)}")
if finish_reason == 'content_filter':
return "Error: [LLM] Content generation blocked due to content filter."
return "Error: [LLM] Empty multimodal response from LLM."
else:
logger.warning(f"[LLM] Multimodal generation returned no choices. Response: {response.model_dump_json(indent=2)}")
return "Error: [LLM] No choices returned from Vision LLM."
except openai.APIError as e:
logger.error(f"[LLM] OpenAI Vision API returned an API Error: {e}", exc_info=True)
return f"Error: [LLM] Vision API Error - {type(e).__name__}: {e}"
# Add other specific openai exceptions as needed (AuthenticationError, RateLimitError, etc.)
except Exception as e:
logger.error(f"Error during LLM multimodal generation: {e}", exc_info=True)
return f"Error: [LLM] Failed to communicate with Vision API - {type(e).__name__}: {e}"
def generate_json(self, Schema_Class: Type[BaseModel], prompt: str, image_bytes: Optional[bytes] = None) -> Union[BaseModel, str]:
if not issubclass(Schema_Class, BaseModel):
logger.error(f"[LLM] Schema_Class must be a Pydantic BaseModel for LLM JSON generation.")
return "Error: [LLM] Invalid schema type provided."
current_model = self.LLM_model_name
messages: List[Dict[str, Any]] = [{"role": "user", "content": []}] # Initialize user content as list
# Prepare content (text and optional image)
text_content = {"type": "text", "text": prompt}
messages[0]["content"].append(text_content) # type: ignore
log_msg_suffix = ""
if image_bytes is not None:
if not self.LLM_vision_model_name:
logger.error("[LLM] LLM vision model name not configured for multimodal JSON.")
return "Error: [LLM] Vision model not configured for multimodal JSON."
current_model = self.LLM_vision_model_name # Use vision model if image is present
base64_url = _image_bytes_to_base64_url(image_bytes)
if not base64_url:
return "Error: [LLM] Failed to convert image to base64 for JSON."
image_content = {"type": "image_url", "image_url": {"url": base64_url}}
messages[0]["content"].append(image_content) # type: ignore
log_msg_suffix = " with image"
# Prepare the tool based on the Pydantic schema
try:
tool_def = openai.pydantic_function_tool(Schema_Class)
tools = [tool_def]
# Tool choice can force the model to use the function, or let it decide.
# Forcing it: tool_choice = {"type": "function", "function": {"name": Schema_Class.__name__}}
# Letting it decide (often better unless you *know* it must be called): tool_choice = "auto"
# Let's explicitly request the tool for structured output
tool_choice = {"type": "function", "function": {"name": tool_def['function']['name']}}
except Exception as tool_err:
logger.error(f"[LLM] Failed to create tool definition from schema {Schema_Class.__name__}: {tool_err}", exc_info=True)
return f"Error: [LLM] Failed to create tool definition - {tool_err}"
try:
log_prompt = prompt[:200] + ('...' if len(prompt) > 200 else '')
logger.debug(f"[LLM] Sending JSON prompt (truncated): {log_prompt}{log_msg_suffix} with schema {Schema_Class.__name__}")
# Add a system prompt to guide the model (optional but helpful)
system_message = {"role": "system", "content": f"You are a helpful assistant. Use the provided '{Schema_Class.__name__}' tool to structure your response based on the user's request."}
final_messages = [system_message] + messages
response = self.client.chat.completions.create(
model=current_model, # Use vision model if image included
messages=final_messages,
tools=tools,
tool_choice=tool_choice, # Request the specific tool
max_tokens=2048, # Adjust as needed
)
logger.debug("[LLM] Received JSON response structure.")
if response.choices:
message = response.choices[0].message
finish_reason = response.choices[0].finish_reason
if message.tool_calls:
if len(message.tool_calls) > 1:
logger.warning(f"[LLM] Multiple tool calls received, using the first one for schema {Schema_Class.__name__}")
tool_call = message.tool_calls[0]
if tool_call.type == 'function' and tool_call.function.name == tool_def['function']['name']:
function_args_str = tool_call.function.arguments
try:
# Parse the arguments string into a dictionary
parsed_args = json.loads(function_args_str)
# Validate and potentially instantiate the Pydantic model
model_instance = Schema_Class.model_validate(parsed_args)
return model_instance # Return the validated Pydantic model instance
# print(parsed_args)
# return parsed_args # Return the parsed dict directly
except json.JSONDecodeError as json_err:
logger.error(f"[LLM] Failed to parse JSON arguments from tool call: {json_err}. Arguments: '{function_args_str}'")
return f"Error: [LLM] Failed to parse JSON arguments - {json_err}"
except Exception as val_err: # Catch Pydantic validation errors if model_validate is used
logger.error(f"[LLM] JSON arguments failed validation for schema {Schema_Class.__name__}: {val_err}. Arguments: {function_args_str}")
return f"Error: [LLM] JSON arguments failed validation - {val_err}"
else:
logger.warning(f"[LLM] Expected function tool call for {Schema_Class.__name__} but got type '{tool_call.type}' or name '{tool_call.function.name}'.")
return f"Error: [LLM] Unexpected tool call type/name received."
elif finish_reason == 'tool_calls':
# This might happen if the model intended to call but failed, or structure is odd
logger.warning(f"[LLM] Finish reason is 'tool_calls' but no tool_calls found in message. Response: {response.model_dump_json(indent=2)}")
return "Error: [LLM] Model indicated tool use but none found."
elif finish_reason == 'content_filter':
logger.warning(f"[LLM] JSON generation blocked due to content filter.")
return "Error: [LLM] Content generation blocked due to content filter."
else:
# Model didn't use the tool
logger.warning(f"[LLM] Model did not use the requested JSON tool {Schema_Class.__name__}. Finish reason: {finish_reason}. Content: {message.content}")
# You might return the text content or an error depending on requirements
# return message.content or "Error: [LLM] Model generated text instead of using the JSON tool."
return f"Error: [LLM] Model did not use the JSON tool. Finish Reason: {finish_reason}."
else:
logger.warning(f"[LLM] JSON generation returned no choices. Response: {response.model_dump_json(indent=2)}")
return "Error: [LLM] No choices returned from LLM for JSON request."
except openai.APIError as e:
logger.error(f"[LLM] OpenAI API returned an API Error during JSON generation: {e}", exc_info=True)
return f"Error: [LLM] API Error (JSON) - {type(e).__name__}: {e}"
# Add other specific openai exceptions (AuthenticationError, RateLimitError, etc.)
except Exception as e:
logger.error(f"Error during LLM JSON generation: {e}", exc_info=True)
return f"Error: [LLM] Failed to communicate with API for JSON - {type(e).__name__}: {e}"
```
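A minimal sketch (not part of the repository) of requesting structured output through generate_json. The StepVerdict schema is hypothetical; on success the client returns a validated Pydantic model instance, on failure an "Error: ..." string.
```python
# Sketch only: structured output via the pydantic_function_tool flow above.
from pydantic import BaseModel, Field
from src.llm.clients.openai_client import OpenAIClient

class StepVerdict(BaseModel):  # hypothetical schema for illustration
    passed: bool = Field(..., description="Whether the test step succeeded.")
    reasoning: str = Field(..., description="Short justification for the verdict.")

client = OpenAIClient()  # reads the LLM_* settings from the environment
result = client.generate_json(StepVerdict, "Did the page show a logout button? Answer using the tool.")
if isinstance(result, StepVerdict):
    print(result.passed, "-", result.reasoning)
else:
    print(result)  # error string from the client
```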
--------------------------------------------------------------------------------
/src/llm/clients/azure_openai_client.py:
--------------------------------------------------------------------------------
```python
# /src/llm/clients/azure_openai_client.py
from PIL import Image
import io
import logging
import time # Import time module
import threading # Import threading for lock
from typing import Type, Optional, Union, List, Dict, Any
logger = logging.getLogger(__name__)
import base64
import json
from ...utils.utils import load_api_key, load_api_base_url, load_api_version, load_llm_model
# --- Provider Specific Imports ---
try:
import openai
from openai import AzureOpenAI
from pydantic import BaseModel # Needed for LLM JSON tool definition
OPENAI_SDK = True
except ImportError:
OPENAI_SDK = False
# Define dummy classes if LLM libs are not installed to avoid NameErrors
class BaseModel: pass
class AzureOpenAI: pass
# --- Helper Function ---
def _image_bytes_to_base64_url(image_bytes: bytes) -> Optional[str]:
"""Converts image bytes to a base64 data URL."""
try:
# Try to determine the image format
img = Image.open(io.BytesIO(image_bytes))
format = img.format
if not format:
logger.warning("Could not determine image format, assuming JPEG.")
format = "jpeg" # Default assumption
else:
format = format.lower()
if format == 'jpg': # Standardize to jpeg
format = 'jpeg'
# Ensure format is supported (common web formats)
if format not in ['jpeg', 'png', 'gif', 'webp']:
logger.warning(f"Unsupported image format '{format}' for base64 URL, defaulting to JPEG.")
format = 'jpeg' # Fallback
encoded_string = base64.b64encode(image_bytes).decode('utf-8')
return f"data:image/{format};base64,{encoded_string}"
except Exception as e:
logger.error(f"Error converting image bytes to base64 URL: {e}", exc_info=True)
return None
class AzureOpenAIClient:
def __init__(self):
self.client = None
self.LLM_api_key = load_api_key()
self.LLM_api_version = load_api_version()
self.LLM_model_name = load_llm_model()
self.LLM_endpoint = load_api_base_url()
self.LLM_vision_model_name = self.LLM_model_name
if not OPENAI_SDK:
raise ImportError("LLM OpenAI libraries (openai, pydantic) are not installed. Please install them.")
if not all([self.LLM_api_key, self.LLM_endpoint, self.LLM_api_version, self.LLM_model_name]):
raise ValueError("LLM_api_key, LLM_endpoint, LLM_api_version, and LLM_model_name are required for provider 'LLM'")
try:
self.client = AzureOpenAI(
api_key=self.LLM_api_key,
azure_endpoint=self.LLM_endpoint,
api_version=self.LLM_api_version
)
# Test connection slightly by listing models (optional, requires different permission potentially)
# self.client.models.list()
logger.info(f"LLM OpenAI Client initialized for endpoint {self.LLM_endpoint} and model {self.LLM_model_name}.")
except Exception as e:
logger.error(f"Failed to initialize LLM OpenAI Client: {e}", exc_info=True)
raise RuntimeError(f"LLM client initialization failed: {e}")
def generate_text(self, prompt: str) -> str:
try:
log_prompt = prompt[:200] + ('...' if len(prompt) > 200 else '')
logger.debug(f"[LLM] Sending text prompt (truncated): {log_prompt}")
messages = [{"role": "user", "content": prompt}]
response = self.client.chat.completions.create(
model=self.LLM_model_name,
messages=messages,
max_tokens=1024, # Adjust as needed
)
logger.debug("[LLM] Received text response.")
if response.choices:
message = response.choices[0].message
if message.content:
return message.content
else:
# Handle cases like function calls if they unexpectedly occur or content filter
finish_reason = response.choices[0].finish_reason
logger.warning(f"[LLM] Text generation returned no content. Finish reason: {finish_reason}. Response: {response.model_dump_json(indent=2)}")
if finish_reason == 'content_filter':
return "Error: [LLM] Content generation blocked due to content filter."
return "Error: [LLM] Empty response from LLM."
else:
logger.warning(f"[LLM] Text generation returned no choices. Response: {response.model_dump_json(indent=2)}")
return "Error: [LLM] No choices returned from LLM."
except openai.APIError as e:
# Handle API error here, e.g. retry or log
logger.error(f"[LLM] OpenAI API returned an API Error: {e}", exc_info=True)
return f"Error: [LLM] API Error - {type(e).__name__}: {e}"
except openai.AuthenticationError as e:
logger.error(f"[LLM] OpenAI API authentication error: {e}", exc_info=True)
return f"Error: [LLM] Authentication Error - {e}"
except openai.RateLimitError as e:
logger.error(f"[LLM] OpenAI API request exceeded rate limit: {e}", exc_info=True)
# Note: Our simple time.sleep might not be enough for LLM's complex limits
return f"Error: [LLM] Rate limit exceeded - {e}"
except Exception as e:
logger.error(f"Error during LLM text generation: {e}", exc_info=True)
return f"Error: [LLM] Failed to communicate with API - {type(e).__name__}: {e}"
def generate_multimodal(self, prompt: str, image_bytes: bytes) -> str:
if not self.LLM_vision_model_name:
logger.error("[LLM] LLM vision model name not configured.")
return "Error: [LLM] Vision model not configured."
base64_url = _image_bytes_to_base64_url(image_bytes)
if not base64_url:
return "Error: [LLM] Failed to convert image to base64."
try:
log_prompt = prompt[:200] + ('...' if len(prompt) > 200 else '')
logger.debug(f"[LLM] Sending multimodal prompt (truncated): {log_prompt} with image.")
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {"url": base64_url}},
],
}
]
response = self.client.chat.completions.create(
model=self.LLM_vision_model_name, # Use the vision model deployment
messages=messages,
max_tokens=1024, # Adjust as needed
)
logger.debug("[LLM] Received multimodal response.")
# Parsing logic similar to text generation
if response.choices:
message = response.choices[0].message
if message.content:
return message.content
else:
finish_reason = response.choices[0].finish_reason
logger.warning(f"[LLM] Multimodal generation returned no content. Finish reason: {finish_reason}. Response: {response.model_dump_json(indent=2)}")
if finish_reason == 'content_filter':
return "Error: [LLM] Content generation blocked due to content filter."
return "Error: [LLM] Empty multimodal response from LLM."
else:
logger.warning(f"[LLM] Multimodal generation returned no choices. Response: {response.model_dump_json(indent=2)}")
return "Error: [LLM] No choices returned from Vision LLM."
except openai.APIError as e:
logger.error(f"[LLM] OpenAI Vision API returned an API Error: {e}", exc_info=True)
return f"Error: [LLM] Vision API Error - {type(e).__name__}: {e}"
# Add other specific openai exceptions as needed (AuthenticationError, RateLimitError, etc.)
except Exception as e:
logger.error(f"Error during LLM multimodal generation: {e}", exc_info=True)
return f"Error: [LLM] Failed to communicate with Vision API - {type(e).__name__}: {e}"
def generate_json(self, Schema_Class: Type[BaseModel], prompt: str, image_bytes: Optional[bytes] = None) -> Union[BaseModel, str]:
if not issubclass(Schema_Class, BaseModel):
logger.error(f"[LLM] Schema_Class must be a Pydantic BaseModel for LLM JSON generation.")
return "Error: [LLM] Invalid schema type provided."
current_model = self.LLM_model_name
messages: List[Dict[str, Any]] = [{"role": "user", "content": []}] # Initialize user content as list
# Prepare content (text and optional image)
text_content = {"type": "text", "text": prompt}
messages[0]["content"].append(text_content) # type: ignore
log_msg_suffix = ""
if image_bytes is not None:
if not self.LLM_vision_model_name:
logger.error("[LLM] LLM vision model name not configured for multimodal JSON.")
return "Error: [LLM] Vision model not configured for multimodal JSON."
current_model = self.LLM_vision_model_name # Use vision model if image is present
base64_url = _image_bytes_to_base64_url(image_bytes)
if not base64_url:
return "Error: [LLM] Failed to convert image to base64 for JSON."
image_content = {"type": "image_url", "image_url": {"url": base64_url}}
messages[0]["content"].append(image_content) # type: ignore
log_msg_suffix = " with image"
# Prepare the tool based on the Pydantic schema
try:
tool_def = openai.pydantic_function_tool(Schema_Class)
tools = [tool_def]
# Tool choice can force the model to use the function, or let it decide.
# Forcing it: tool_choice = {"type": "function", "function": {"name": Schema_Class.__name__}}
# Letting it decide (often better unless you *know* it must be called): tool_choice = "auto"
# Let's explicitly request the tool for structured output
tool_choice = {"type": "function", "function": {"name": tool_def['function']['name']}}
except Exception as tool_err:
logger.error(f"[LLM] Failed to create tool definition from schema {Schema_Class.__name__}: {tool_err}", exc_info=True)
return f"Error: [LLM] Failed to create tool definition - {tool_err}"
try:
log_prompt = prompt[:200] + ('...' if len(prompt) > 200 else '')
logger.debug(f"[LLM] Sending JSON prompt (truncated): {log_prompt}{log_msg_suffix} with schema {Schema_Class.__name__}")
# Add a system prompt to guide the model (optional but helpful)
system_message = {"role": "system", "content": f"You are a helpful assistant. Use the provided '{Schema_Class.__name__}' tool to structure your response based on the user's request."}
final_messages = [system_message] + messages
response = self.client.chat.completions.create(
model=current_model, # Use vision model if image included
messages=final_messages,
tools=tools,
tool_choice=tool_choice, # Request the specific tool
max_tokens=2048, # Adjust as needed
)
logger.debug("[LLM] Received JSON response structure.")
if response.choices:
message = response.choices[0].message
finish_reason = response.choices[0].finish_reason
if message.tool_calls:
if len(message.tool_calls) > 1:
logger.warning(f"[LLM] Multiple tool calls received, using the first one for schema {Schema_Class.__name__}")
tool_call = message.tool_calls[0]
if tool_call.type == 'function' and tool_call.function.name == tool_def['function']['name']:
function_args_str = tool_call.function.arguments
try:
# Parse the arguments string into a dictionary
parsed_args = json.loads(function_args_str)
# Validate and potentially instantiate the Pydantic model
model_instance = Schema_Class.model_validate(parsed_args)
return model_instance # Return the validated Pydantic model instance
# print(parsed_args)
# return parsed_args # Return the parsed dict directly
except json.JSONDecodeError as json_err:
logger.error(f"[LLM] Failed to parse JSON arguments from tool call: {json_err}. Arguments: '{function_args_str}'")
return f"Error: [LLM] Failed to parse JSON arguments - {json_err}"
except Exception as val_err: # Catch Pydantic validation errors if model_validate is used
logger.error(f"[LLM] JSON arguments failed validation for schema {Schema_Class.__name__}: {val_err}. Arguments: {function_args_str}")
return f"Error: [LLM] JSON arguments failed validation - {val_err}"
else:
logger.warning(f"[LLM] Expected function tool call for {Schema_Class.__name__} but got type '{tool_call.type}' or name '{tool_call.function.name}'.")
return f"Error: [LLM] Unexpected tool call type/name received."
elif finish_reason == 'tool_calls':
# This might happen if the model intended to call but failed, or structure is odd
logger.warning(f"[LLM] Finish reason is 'tool_calls' but no tool_calls found in message. Response: {response.model_dump_json(indent=2)}")
return "Error: [LLM] Model indicated tool use but none found."
elif finish_reason == 'content_filter':
logger.warning(f"[LLM] JSON generation blocked due to content filter.")
return "Error: [LLM] Content generation blocked due to content filter."
else:
# Model didn't use the tool
logger.warning(f"[LLM] Model did not use the requested JSON tool {Schema_Class.__name__}. Finish reason: {finish_reason}. Content: {message.content}")
# You might return the text content or an error depending on requirements
# return message.content or "Error: [LLM] Model generated text instead of using the JSON tool."
return f"Error: [LLM] Model did not use the JSON tool. Finish Reason: {finish_reason}."
else:
logger.warning(f"[LLM] JSON generation returned no choices. Response: {response.model_dump_json(indent=2)}")
return "Error: [LLM] No choices returned from LLM for JSON request."
except openai.APIError as e:
logger.error(f"[LLM] OpenAI API returned an API Error during JSON generation: {e}", exc_info=True)
return f"Error: [LLM] API Error (JSON) - {type(e).__name__}: {e}"
# Add other specific openai exceptions (AuthenticationError, RateLimitError, etc.)
except Exception as e:
logger.error(f"Error during LLM JSON generation: {e}", exc_info=True)
return f"Error: [LLM] Failed to communicate with API for JSON - {type(e).__name__}: {e}"
```
--------------------------------------------------------------------------------
/mcp_server.py:
--------------------------------------------------------------------------------
```python
# mcp_server.py
import sys
import os
import json
import logging
from typing import List, Dict, Any, Optional
import asyncio
import re
import time
from datetime import datetime
# Ensure agent modules are importable (adjust path if necessary)
# Assuming mcp_server.py is at the root level alongside agent.py etc.
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
from mcp.server.fastmcp import FastMCP
from mcp.server.fastmcp.prompts import base as mcp_prompts
# Import necessary components from your existing code
from src.agents.recorder_agent import WebAgent # Needs refactoring for non-interactive use
from src.agents.crawler_agent import CrawlerAgent
from src.llm.llm_client import LLMClient
from src.execution.executor import TestExecutor
from src.utils.utils import load_api_key, load_api_base_url, load_api_version, load_llm_model
from src.security.semgrep_scanner import run_semgrep
from src.security.zap_scanner import run_zap_scan, discover_endpoints
from src.security.nuclei_scanner import run_nuclei
from src.security.utils import save_report
# Configure logging for the MCP server
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - [MCP Server] %(message)s')
logger = logging.getLogger(__name__)
# Define the output directory for tests (consistent with agent/executor)
TEST_OUTPUT_DIR = "output"
# --- Initialize FastMCP Server ---
mcp = FastMCP("WebTestAgentServer")
llm_client = LLMClient(provider='azure')
# --- MCP Tool: Record a New Test Flow (Automated - Requires Agent Refactoring) ---
@mcp.tool()
async def record_test_flow(feature_description: str, project_directory: str, headless: bool = True) -> Dict[str, Any]:
"""
Attempts to automatically record a web test flow based on a natural language description. If a case fails, it may be because a step in the feature description was missing or incorrect. Avoid vague actions like 'select anything'; give exact actions such as 'select the <specific element>'.
Uses the WebAgent in automated mode (bypasses interactive prompts). Do not skip any step; provide complete end-to-end steps covering both what to do and what to verify.
Args:
feature_description: A natural language description of the test case or user flow. Crucially, this description MUST explicitly include the starting URL of the website to be tested (e.g., 'Go to https://example.com, then click...'). Do not give vague inputs; be exact, e.g., 'enter invalid-email into the email box' or 'enter user@example.com into the email box'.
project_directory: The project directory you are currently working in. This is used to identify the test flows of a project
headless: Run the underlying browser in headless mode. Defaults to True.
Returns:
A dictionary containing the recording status, including success/failure,
message, and the path to the generated test JSON file if successful.
"""
logger.info(f"Received automated request to record test flow: '{feature_description[:100]}...' (Headless: {headless})")
try:
# 1. Instantiate WebAgent in AUTOMATED mode
recorder_agent = WebAgent(
llm_client=llm_client,
headless=headless, # Allow MCP tool to specify headless
is_recorder_mode=True,
automated_mode=True, # <<< Set automated mode
max_retries_per_subtask=2,
filename=re.sub(r"[ /]", "_", project_directory)
)
# Run the blocking recorder_agent.record method in a separate thread
# Pass the method and its arguments to asyncio.to_thread
logger.info("Delegating agent recording to a separate thread...")
recording_result = await asyncio.to_thread(recorder_agent.record, feature_description)
logger.info(f"Automated recording finished (thread returned). Result: {recording_result}")
return recording_result
except Exception as e:
logger.error(f"Error in record_test_flow tool: {e}", exc_info=True)
return {"success": False, "message": f"Internal server error during automated recording: {e}"}
# --- MCP Tool: Run a Single Regression Test ---
@mcp.tool()
async def run_regression_test(test_file_path: str, headless: bool = True, enable_healing: bool = True, healing_mode: str = 'soft', get_performance: bool = False, get_network_requests: bool = False) -> Dict[str, Any]:
"""
Runs a previously recorded test case from a JSON file. If a case fails, it could be because the code under test has a problem, or because a step in the feature description was missing or incorrect.
Args:
test_file_path: The relative or absolute path to the .json test file (e.g., 'output/test_login.json').
headless: Run the browser in headless mode (no visible window). Defaults to True.
enable_healing: Whether to run this regression test with healing mode enabled. In healing mode, if the test fails because of a changed or flaky selector, the agent can try to heal the test automatically.
healing_mode: 'soft' or 'hard'. In soft mode, only the failing step is healed. In hard mode, the complete test is re-recorded.
get_performance: Whether to include performance stats in the response.
get_network_requests: Whether to include network stats in the response.
Returns:
A dictionary containing the execution result summary, including status (PASS/FAIL),
duration, message, error details (if failed), and evidence paths.
"""
logger.info(f"Received request to run regression test: '{test_file_path}', Headless: {headless}")
# Basic path validation (relative to server or absolute)
if not os.path.isabs(test_file_path):
# Assume relative to the server's working directory or a known output dir
# For simplicity, let's check relative to CWD and TEST_OUTPUT_DIR
potential_paths = [
test_file_path,
os.path.join(TEST_OUTPUT_DIR, test_file_path)
]
found_path = None
for p in potential_paths:
if os.path.exists(p) and os.path.isfile(p):
found_path = p
break
if not found_path:
logger.error(f"Test file not found at '{test_file_path}' or within '{TEST_OUTPUT_DIR}'.")
return {"success": False, "status": "ERROR", "message": f"Test file not found: {test_file_path}"}
test_file_path = os.path.abspath(found_path) # Use absolute path for executor
logger.info(f"Resolved test file path to: {test_file_path}")
try:
# Instantiate the executor; the LLM client is passed through for healing support
executor = TestExecutor(
headless=headless,
llm_client=llm_client,
enable_healing=enable_healing,
healing_mode=healing_mode,
get_network_requests=get_network_requests,
get_performance=get_performance
)
logger.info(f"Delegating test execution for '{test_file_path}' to a separate thread...")
test_result = await asyncio.to_thread(
executor.run_test, # The function to run
test_file_path # Arguments for the function
)
# Add a success flag for generic tool success/failure indication
# Post-processing (synchronous)
test_result["success"] = test_result.get("status") == "PASS"
logger.info(f"Execution finished for '{test_file_path}' (thread returned). Status: {test_result.get('status')}")
try:
base_name = os.path.splitext(os.path.basename(test_file_path))[0]
result_filename = os.path.join("output", f"execution_result_{base_name}_{time.strftime('%Y%m%d_%H%M%S')}.json")
with open(result_filename, 'w', encoding='utf-8') as f:
json.dump(test_result, f, indent=2, ensure_ascii=False)
print(f"\nFull execution result details saved to: {result_filename}")
except Exception as save_err:
logger.error(f"Failed to save full execution result JSON: {save_err}")
return test_result
except FileNotFoundError:
logger.error(f"Test file not found by executor: {test_file_path}")
return {"success": False, "status": "ERROR", "message": f"Test file not found: {test_file_path}"}
except Exception as e:
logger.error(f"Error running regression test '{test_file_path}': {e}", exc_info=True)
return {"success": False, "status": "ERROR", "message": f"Internal server error during execution: {e}"}
@mcp.tool()
async def discover_test_flows(start_url: str, max_pages_to_crawl: int = 10, headless: bool = True) -> Dict[str, Any]:
"""
Crawls a website starting from a given URL within the same domain, analyzes page content
(DOM, Screenshot), and uses an LLM to suggest potential specific test step descriptions
for each discovered page.
Args:
start_url: The URL to begin crawling from (e.g., 'https://example.com').
max_pages_to_crawl: The maximum number of unique pages to visit (default: 10).
headless: Run the crawler's browser in headless mode (default: True).
Returns:
A dictionary containing the crawl summary, including success status,
pages visited, and a dictionary mapping visited URLs to suggested test step descriptions.
Example: {"success": true, "discovered_steps": {"https://example.com/login": ["Type 'user' into Username field", ...]}}
"""
logger.info(f"Received request to discover test flows starting from: '{start_url}', Max Pages: {max_pages_to_crawl}, Headless: {headless}")
try:
# 1. Instantiate CrawlerAgent
crawler = CrawlerAgent(
llm_client=llm_client,
headless=headless
)
# 2. Run the blocking crawl method in a separate thread
logger.info("Delegating crawler execution to a separate thread...")
crawl_results = await asyncio.to_thread(
crawler.crawl_and_suggest,
start_url,
max_pages_to_crawl
)
logger.info(f"Crawling finished (thread returned). Visited: {crawl_results.get('pages_visited')}, Suggestions: {len(crawl_results.get('discovered_steps', {}))}")
# Return the results dictionary from the crawler
return crawl_results
except Exception as e:
logger.error(f"Error in discover_test_flows tool: {e}", exc_info=True)
return {"success": False, "message": f"Internal server error during crawling: {e}", "discovered_steps": {}}
# --- MCP Resource: List Recorded Tests ---
@mcp.tool()
def list_recorded_tests(project_directory: str) -> List[str]:
"""
Provides a list of available test JSON files in the standard output directory.
Args:
project_directory: The project directory you are currently working in. This is used to identify the test flows of a project
Returns:
test_files: A list of filenames for each test flow (e.g., ["test_login_flow_....json", "test_search_....json"]).
"""
logger.info(f"Providing resource list of tests from '{TEST_OUTPUT_DIR}'")
if not os.path.exists(TEST_OUTPUT_DIR) or not os.path.isdir(TEST_OUTPUT_DIR):
logger.warning(f"Test output directory '{TEST_OUTPUT_DIR}' not found.")
return []
try:
test_files = [
f for f in os.listdir(TEST_OUTPUT_DIR)
if os.path.isfile(os.path.join(TEST_OUTPUT_DIR, f)) and f.endswith(".json") and f.startswith(re.sub(r"[ /]", "_", project_directory))
]
# Exclude saved execution result files, returning only the recorded test flows
test_files = [f for f in test_files if not f.startswith("execution_result_")]
return sorted(test_files)
except Exception as e:
logger.error(f"Error listing test files in '{TEST_OUTPUT_DIR}': {e}", exc_info=True)
# Re-raise or return empty list? Returning empty is safer for resource.
return []
@mcp.tool()
def get_security_scan(project_directory: str, target_url: str = None, semgrep_config: str = 'auto') -> Dict[str, Any]:
"""
Scans for vulnerabilities using Semgrep (static code analysis) plus ZAP and Nuclei (dynamic scanning against the target URL).
Also discovers endpoints using ZAP's spider functionality. Try to fix findings automatically if you judge them to be true positives.
Args:
project_directory: The project directory which you want to scan for security issues. Give absolute path only.
target_url: The target URL for dynamic scanning (ZAP and Nuclei). Required for endpoint discovery.
semgrep_config: The config for semgrep scans. Default: 'auto'
Returns:
Dict containing:
- vulnerabilities: List of vulnerabilities found
- endpoints: List of discovered endpoints (if target_url provided)
"""
logging.info("--- Starting Phase 1: Security Scanning ---")
all_findings = []
discovered_endpoints = []
if project_directory:
# Run Semgrep scan
logging.info("--- Running Semgrep Scan ---")
semgrep_findings = run_semgrep(
code_path=project_directory,
config=semgrep_config,
output_dir='./results',
timeout=600
)
if semgrep_findings:
logging.info(f"Completed Semgrep Scan. Found {len(semgrep_findings)} potential issues.")
all_findings.extend(semgrep_findings)
else:
logging.warning("Semgrep scan completed with no findings or failed.")
all_findings.append({"Warning": "Semgrep scan completed with no findings or failed."})
if target_url:
# First, discover endpoints using ZAP spider
logging.info("--- Running Endpoint Discovery ---")
try:
discovered_endpoints = discover_endpoints(
target_url=target_url,
output_dir='./results',
timeout=600 # 10 minutes for discovery
)
logging.info(f"Discovered {len(discovered_endpoints)} endpoints")
except Exception as e:
logging.error(f"Error during endpoint discovery: {e}")
discovered_endpoints = []
# Run ZAP scan
logging.info("--- Running ZAP Scan ---")
try:
zap_findings = run_zap_scan(
target_url=target_url,
output_dir='./results',
scan_mode="baseline" # Using baseline scan for quicker results
)
if zap_findings and not isinstance(zap_findings[0], str):
logging.info(f"Completed ZAP Scan. Found {len(zap_findings)} potential issues.")
all_findings.extend(zap_findings)
else:
logging.warning("ZAP scan completed with no findings or failed.")
all_findings.append({"Warning": "ZAP scan completed with no findings or failed."})
except Exception as e:
logging.error(f"Error during ZAP scan: {e}")
all_findings.append({"Error": f"ZAP scan failed: {str(e)}"})
# Run Nuclei scan
logging.info("--- Running Nuclei Scan ---")
try:
nuclei_findings = run_nuclei(
target_url=target_url,
output_dir='./results'
)
if nuclei_findings and not isinstance(nuclei_findings[0], str):
logging.info(f"Completed Nuclei Scan. Found {len(nuclei_findings)} potential issues.")
all_findings.extend(nuclei_findings)
else:
logging.warning("Nuclei scan completed with no findings or failed.")
all_findings.append({"Warning": "Nuclei scan completed with no findings or failed."})
except Exception as e:
logging.error(f"Error during Nuclei scan: {e}")
all_findings.append({"Error": f"Nuclei scan failed: {str(e)}"})
else:
logging.info("Skipping dynamic scans and endpoint discovery as target_url was not provided.")
else:
logging.info("Skipping scans as project_directory was not provided.")
all_findings.append({"Warning": "Skipping scans as project_directory was not provided"})
logging.info("--- Phase 1: Security Scanning Complete ---")
logging.info("--- Starting Phase 2: Consolidating Results ---")
logging.info(f"Total findings aggregated from all tools: {len(all_findings)}")
# Save the consolidated report
consolidated_report_path = save_report(all_findings, "consolidated", './results/', "consolidated_scan_results")
if consolidated_report_path:
logging.info(f"Consolidated report saved to: {consolidated_report_path}")
print(f"\nConsolidated report saved to: {consolidated_report_path}")
else:
logging.error("Failed to save the consolidated report.")
# Save discovered endpoints if any
if discovered_endpoints:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
endpoints_file = os.path.join('./results', f"discovered_endpoints_{timestamp}.json")
try:
with open(endpoints_file, 'w') as f:
json.dump(discovered_endpoints, f, indent=2)
logging.info(f"Saved discovered endpoints to: {endpoints_file}")
except Exception as e:
logging.error(f"Failed to save endpoints report: {e}")
logging.info("--- Phase 2: Consolidation Complete ---")
logging.info("--- Security Automation Script Finished ---")
return {
"vulnerabilities": all_findings,
"endpoints": discovered_endpoints
}
# --- Running the Server ---
# The server is normally launched via `mcp dev` or `mcp install`.
# The __main__ block below additionally allows running it directly with `python mcp_server.py`.
logger.info("WebTestAgent MCP Server defined. Run with 'mcp dev mcp_server.py'")
if __name__ == "__main__":
mcp.run()
```
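The tools above can also be smoke-tested directly from Python, outside an MCP client. A rough sketch, assuming FastMCP's `@mcp.tool()` decorator leaves the plain function callable and that the LLM environment variables are set; paths and project name are placeholders:

```python
import asyncio
from mcp_server import list_recorded_tests, run_regression_test

# Recorded flows live in ./output, prefixed with the sanitised project path.
tests = list_recorded_tests("/abs/path/to/my_project")
print(tests)

if tests:
    # Relative filenames are resolved against the output directory by the tool itself.
    result = asyncio.run(run_regression_test(tests[0], headless=True, healing_mode="soft"))
    print(result.get("status"), result.get("message"))
```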
--------------------------------------------------------------------------------
/src/dom/views.py:
--------------------------------------------------------------------------------
```python
# /src/dom/views.py
from dataclasses import dataclass, field, KW_ONLY # Use field for default_factory
from functools import cached_property
from typing import TYPE_CHECKING, Dict, List, Optional, Union, Literal, Tuple
import re # Added for selector generation
# Use relative imports if within the same package structure
from .history.view import CoordinateSet, HashedDomElement, ViewportInfo # Adjusted import
# Placeholder decorator if not using utils.time_execution_sync
def time_execution_sync(label):
def decorator(func):
def wrapper(*args, **kwargs):
# Basic logging
# logger.debug(f"Executing {label}...")
result = func(*args, **kwargs)
# logger.debug(f"Finished {label}.")
return result
return wrapper
return decorator
# Avoid circular import issues
if TYPE_CHECKING:
# This creates a forward reference issue if DOMElementNode itself is in this file.
# We need to define DOMElementNode before DOMBaseNode if DOMBaseNode references it.
# Let's adjust the structure slightly or use string hints.
pass # Forward reference handled by structure/string hints below
@dataclass(frozen=False)
class DOMBaseNode:
# Parent needs to be Optional and potentially use string hint if defined later
parent: Optional['DOMElementNode'] = None # Default to None
is_visible: bool = False # Provide default
@dataclass(frozen=False)
class DOMTextNode(DOMBaseNode):
# --- Field ordering within subclass matters less with KW_ONLY ---
# --- but arguments after the marker MUST be passed by keyword ---
_ : KW_ONLY # <--- Add KW_ONLY marker
# Fields defined in this class (now keyword-only)
text: str
type: str = 'TEXT_NODE'
def has_parent_with_highlight_index(self) -> bool:
current = self.parent
while current is not None:
if current.highlight_index is not None:
return True
current = current.parent
return False
# These visibility checks might be less useful now that JS handles it, but keep for now
def is_parent_in_viewport(self) -> bool:
if self.parent is None:
return False
return self.parent.is_in_viewport
def is_parent_top_element(self) -> bool:
if self.parent is None:
return False
return self.parent.is_top_element
# Define DOMElementNode *before* DOMBaseNode references it fully, or ensure Optional['DOMElementNode'] works
@dataclass(frozen=False)
class DOMElementNode(DOMBaseNode):
"""
Represents an element node in the processed DOM tree.
Includes information about interactivity, visibility, and structure.
"""
tag_name: str = ""
xpath: str = ""
attributes: Dict[str, str] = field(default_factory=dict)
# Use Union with string hint for forward reference if needed, or ensure DOMTextNode is defined first
children: List[Union['DOMElementNode', DOMTextNode]] = field(default_factory=list)
is_interactive: bool = False
is_top_element: bool = False
is_in_viewport: bool = False
shadow_root: bool = False
highlight_index: Optional[int] = None
page_coordinates: Optional[CoordinateSet] = None
viewport_coordinates: Optional[CoordinateSet] = None
viewport_info: Optional[ViewportInfo] = None
css_selector: Optional[str] = None # Added field for robust selector
def __repr__(self) -> str:
# ... (repr logic remains the same) ...
tag_str = f'<{self.tag_name}'
for key, value in self.attributes.items():
# Shorten long values in repr
value_repr = value if len(value) < 50 else value[:47] + '...'
tag_str += f' {key}="{value_repr}"'
tag_str += '>'
extras = []
if self.is_interactive: extras.append('interactive')
if self.is_top_element: extras.append('top')
if self.is_in_viewport: extras.append('in-viewport')
if self.shadow_root: extras.append('shadow-root')
if self.highlight_index is not None: extras.append(f'highlight:{self.highlight_index}')
if self.css_selector: extras.append(f'css:"{self.css_selector[:50]}..."') # Show generated selector
if extras:
tag_str += f' [{", ".join(extras)}]'
return tag_str
@cached_property
def hash(self) -> HashedDomElement:
""" Lazily computes and caches the hash of the element using HistoryTreeProcessor. """
# Use relative import within the method to avoid top-level circular dependencies
from .history.service import HistoryTreeProcessor
# Ensure HistoryTreeProcessor._hash_dom_element exists and is static or accessible
return HistoryTreeProcessor._hash_dom_element(self)
def get_all_text_till_next_clickable_element(self, max_depth: int = -1) -> str:
"""
Recursively collects all text content within this element, stopping descent
if a nested interactive element (with a highlight_index) is encountered.
"""
text_parts = []
def collect_text(node: Union['DOMElementNode', DOMTextNode], current_depth: int) -> None:
if max_depth != -1 and current_depth > max_depth:
return
# Check if the node itself is interactive and not the starting node
if isinstance(node, DOMElementNode) and node is not self and node.highlight_index is not None:
# Stop recursion down this path if we hit an interactive element
return
if isinstance(node, DOMTextNode):
# Only include visible text nodes
if node.is_visible:
text_parts.append(node.text)
elif isinstance(node, DOMElementNode):
# Recursively process children
for child in node.children:
collect_text(child, current_depth + 1)
# Start collection from the element itself
collect_text(self, 0)
# Join collected parts and clean up whitespace
return '\n'.join(filter(None, (tp.strip() for tp in text_parts))).strip()
@time_execution_sync('--clickable_elements_to_string')
def generate_llm_context_string(self,
include_attributes: Optional[List[str]] = None,
max_static_elements_action: int = 50, # Max static elements for action context
max_static_elements_verification: int = 150, # Allow more static elements for verification context
context_purpose: Literal['action', 'verification'] = 'action' # New parameter
) -> Tuple[str, Dict[str, 'DOMElementNode']]:
"""
Generates a string representation of VISIBLE elements tree for LLM context.
Clearly distinguishes interactive elements (with index) from static ones.
Assigns temporary IDs to static elements for later lookup.
Args:
include_attributes: List of specific attributes to include. If None, uses defaults.
max_static_elements_action: Max static elements for 'action' purpose.
max_static_elements_verification: Max static elements for 'verification' purpose.
context_purpose: 'action' (concise) or 'verification' (more inclusive static).
Returns:
Tuple containing:
- The formatted context string.
- A dictionary mapping temporary static IDs (e.g., "s1", "s2")
to the corresponding DOMElementNode objects.
"""
formatted_lines = []
processed_node_ids = set()
static_element_count = 0
nodes_processed_count = 0
static_id_counter = 1 # Counter for temporary static IDs
temp_static_id_map: Dict[str, 'DOMElementNode'] = {} # Map temporary ID to node
max_static_elements = max_static_elements_verification if context_purpose == 'verification' else max_static_elements_action
def get_direct_visible_text(node: DOMElementNode, max_len=10000) -> str:
"""Gets text directly within this node, ignoring children elements."""
texts = []
for child in node.children:
if isinstance(child, DOMTextNode) and child.is_visible:
texts.append(child.text.strip())
full_text = ' '.join(filter(None, texts))
if len(full_text) > max_len:
return full_text[:max_len-3] + "..."
return full_text
def get_parent_hint(node: DOMElementNode) -> Optional[str]:
"""Gets a hint string for the nearest identifiable parent."""
parent = node.parent
if isinstance(parent, DOMElementNode):
parent_attrs = parent.attributes
hint_parts = []
if parent_attrs.get('id'):
hint_parts.append(f"id=\"{parent_attrs['id'][:20]}\"") # Limit length
if parent_attrs.get('data-testid'):
hint_parts.append(f"data-testid=\"{parent_attrs['data-testid'][:20]}\"")
# Add class hint only if specific? Maybe too noisy. Start with id/testid.
# if parent_attrs.get('class'):
# stable_classes = [c for c in parent_attrs['class'].split() if len(c) > 3 and not c.isdigit()]
# if stable_classes: hint_parts.append(f"class=\"{stable_classes[0][:15]}...\"") # Show first stable class
if hint_parts:
return f"(inside: <{parent.tag_name} {' '.join(hint_parts)}>)"
return None
def process_node(node: Union['DOMElementNode', DOMTextNode], depth: int) -> None:
nonlocal static_element_count, nodes_processed_count, static_id_counter # Allow modification
# Skip if already processed or not an element
if not isinstance(node, DOMElementNode): return
nodes_processed_count += 1
node_id = id(node)
if node_id in processed_node_ids: return
processed_node_ids.add(node_id)
is_node_visible = node.is_visible
visibility_marker = "" if is_node_visible else " (Not Visible)"
should_add_current_node = False
line_to_add = ""
is_interactive = node.highlight_index is not None
temp_static_id_assigned = None # Track if ID was assigned to this node
indent = ' ' * depth
# --- Attribute Extraction (Common logic) ---
attributes_to_show = {}
default_attrs = ['id', 'name', 'class', 'aria-label', 'placeholder', 'role', 'type', 'value', 'title', 'alt', 'href', 'data-testid', 'data-value']
attrs_to_check = include_attributes if include_attributes else default_attrs
extract_attrs_for_this_node = is_interactive or (context_purpose == 'verification')
if extract_attrs_for_this_node:
for attr_key in attrs_to_check:
if attr_key in node.attributes and node.attributes[attr_key] is not None: # Check for not None
# Simple check to exclude extremely long class lists for brevity, unless it's ID/testid
if attr_key == 'class' and len(node.attributes[attr_key]) > 20 and context_purpose == 'action':
attributes_to_show[attr_key] = node.attributes[attr_key][:97] + "..."
else:
attributes_to_show[attr_key] = node.attributes[attr_key]
attrs_str = ""
if attributes_to_show:
parts = []
for key, value in attributes_to_show.items():
value_str = str(value) # Ensure it's a string
# Limit length for display
display_value = value_str if len(value_str) < 50 else value_str[:47] + '...'
# *** CORRECT HTML ESCAPING for attribute value strings ***
display_value = display_value.replace('&', '&amp;').replace('<', '&lt;').replace('>', '&gt;').replace('"', '&quot;')
parts.append(f'{key}="{display_value}"')
attrs_str = " ".join(parts)
# --- Format line based on Interactive vs. Static ---
if is_interactive:
# == INTERACTIVE ELEMENT == (Always include)
text_content = node.get_all_text_till_next_clickable_element()
text_content = ' '.join(text_content.split()) if text_content else ""
# Truncate long text for display
if len(text_content) > 150: text_content = text_content[:147] + "..."
line_to_add = f"{indent}[{node.highlight_index}]<{node.tag_name}"
if attrs_str: line_to_add += f" {attrs_str}"
if text_content: line_to_add += f">{text_content}</{node.tag_name}>"
else: line_to_add += " />"
line_to_add += visibility_marker
should_add_current_node = True
elif static_element_count < max_static_elements:
# == VISIBLE STATIC ELEMENT ==
text_content = get_direct_visible_text(node)
include_this_static = False
# Determine if static node is relevant for verification
if context_purpose == 'verification':
common_static_tags = {'p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'span', 'div', 'li', 'label', 'td', 'th', 'strong', 'em', 'dt', 'dd'}
# Include if common tag OR has text OR *has attributes calculated in attrs_str*
if node.tag_name in common_static_tags or text_content or attrs_str:
include_this_static = True
# Regardless of tag/attributes, skip static elements with no direct text to keep the context concise
if not text_content:
include_this_static = False
if include_this_static:
# --- Assign temporary static ID ---
current_static_id = f"s{static_id_counter}"
temp_static_id_map[current_static_id] = node
temp_static_id_assigned = current_static_id # Mark that ID was assigned
static_id_counter += 1
# *** Start building the line ***
line_to_add = f"{indent}<{node.tag_name}"
# *** CRUCIAL: Add the calculated attributes string ***
if attrs_str:
line_to_add += f" {attrs_str}"
# --- Add the static ID attribute to the string ---
line_to_add += f' data-static-id="{current_static_id}"'
# *** Add the static marker ***
line_to_add += " (Static)"
line_to_add += visibility_marker
# *** Add parent hint ONLY if element lacks key identifiers ***
node_attrs = node.attributes # Use original attributes for this check
has_key_identifier = node_attrs.get('id') or node_attrs.get('data-testid') or node_attrs.get('name')
if not has_key_identifier:
parent_hint = get_parent_hint(node)
if parent_hint:
line_to_add += f" {parent_hint}"
# *** Add text content and close tag ***
if text_content:
line_to_add += f">{text_content}</{node.tag_name}>"
else:
line_to_add += " />"
should_add_current_node = True
static_element_count += 1
# --- Add the formatted line if needed ---
if should_add_current_node:
formatted_lines.append(line_to_add)
# logger.debug(f"Added line: {line_to_add}") # Optional debug
# --- ALWAYS Recurse into children (unless static limit hit) ---
# We recurse even if the parent wasn't added, because children might be visible/interactive
if static_element_count >= max_static_elements:
# Stop recursing down static branches if limit is hit
pass
else:
for child in node.children:
# Pass DOMElementNode or DOMTextNode directly
process_node(child, depth + 1)
# Start processing from the root element
process_node(self, 0)
# logger.debug(f"Finished generate_llm_context_string. Processed {nodes_processed_count} nodes. Added {len(formatted_lines)} lines.") # Log summary
output_str = '\n'.join(formatted_lines)
if static_element_count >= max_static_elements:
output_str += f"\n... (Static element list truncated after {max_static_elements} entries)"
return output_str, temp_static_id_map
def get_file_upload_element(self, check_siblings: bool = True) -> Optional['DOMElementNode']:
# Check if current element is a file input
if self.tag_name == 'input' and self.attributes.get('type') == 'file':
return self
# Check children
for child in self.children:
if isinstance(child, DOMElementNode):
result = child.get_file_upload_element(check_siblings=False)
if result:
return result
# Check siblings only for the initial call
if check_siblings and self.parent:
for sibling in self.parent.children:
if sibling is not self and isinstance(sibling, DOMElementNode):
result = sibling.get_file_upload_element(check_siblings=False)
if result:
return result
return None
# Type alias for the selector map
SelectorMap = Dict[int, DOMElementNode]
@dataclass
class DOMState:
"""Holds the state of the processed DOM at a point in time."""
element_tree: DOMElementNode
selector_map: SelectorMap
```
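As a quick illustration of how `generate_llm_context_string` renders interactive versus static nodes, here is a small hand-built tree (a sketch, not repo code; it assumes the repository root is on `PYTHONPATH`):

```python
from src.dom.views import DOMElementNode, DOMTextNode

# One interactive button (highlight_index set) and one static heading.
label = DOMTextNode(is_visible=True, text="Sign in")
button = DOMElementNode(tag_name="button", attributes={"id": "login-btn", "type": "submit"},
                        children=[label], is_visible=True, is_interactive=True, highlight_index=0)
label.parent = button

heading_text = DOMTextNode(is_visible=True, text="Welcome back")
heading = DOMElementNode(tag_name="h1", children=[heading_text], is_visible=True)
heading_text.parent = heading

root = DOMElementNode(tag_name="body", children=[heading, button], is_visible=True)
heading.parent = root
button.parent = root

context_str, static_map = root.generate_llm_context_string(context_purpose="verification")
print(context_str)  # heading rendered with a data-static-id and "(Static)", button as "[0]<button ...>"
print(static_map)   # {"s1": <the h1 node>}
```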
--------------------------------------------------------------------------------
/src/agents/crawler_agent.py:
--------------------------------------------------------------------------------
```python
# /src/agents/crawler_agent.py
import logging
import time
from urllib.parse import urlparse, urljoin
from typing import List, Set, Dict, Any, Optional
import asyncio # For potential async operations if BrowserController becomes async
import re
import os
import json
from pydantic import BaseModel, Field
# Use relative imports within the package if applicable, or adjust paths
from ..browser.browser_controller import BrowserController
from ..llm.llm_client import LLMClient
logger = logging.getLogger(__name__)
# --- Pydantic Schema for LLM Response ---
class SuggestedTestStepsSchema(BaseModel):
"""Schema for suggested test steps relevant to the current page."""
suggested_test_steps: List[str] = Field(..., description="List of 3-5 specific, actionable test step descriptions (like 'Click button X', 'Type Y into Z', 'Verify text A') relevant to the current page context.")
reasoning: str = Field(..., description="Brief reasoning for why these steps are relevant to the page content and URL.")
# --- Crawler Agent ---
class CrawlerAgent:
"""
Crawls a given domain, identifies unique pages, and uses an LLM
to suggest potential test flows for each discovered page.
"""
def __init__(self, llm_client: LLMClient, headless: bool = True, politeness_delay_sec: float = 1.0):
self.llm_client = llm_client
self.headless = headless
self.politeness_delay = politeness_delay_sec
self.browser_controller: Optional[BrowserController] = None
# State for crawling
self.base_domain: Optional[str] = None
self.queue: List[str] = []
self.visited_urls: Set[str] = set()
self.discovered_steps: Dict[str, List[str]] = {}
def _normalize_url(self, url: str) -> str:
"""Removes fragments and trailing slashes for consistent URL tracking."""
parsed = urlparse(url)
# Rebuild without fragment, ensure path starts with / if root
path = parsed.path if parsed.path else '/'
if len(path) > 1 and path.endswith('/'):
path = path[:-1] # Remove trailing slash unless the path is just '/'
# Ensure scheme and netloc are present
scheme = parsed.scheme if parsed.scheme else 'http' # Default to http if missing? Or raise error? Let's default.
netloc = parsed.netloc
if not netloc:
logger.warning(f"URL '{url}' missing network location (domain). Skipping.")
return None # Invalid URL for crawling
# Query params are usually important, keep them
query = parsed.query
# Reconstruct
rebuilt_url = f"{scheme}://{netloc}{path}"
if query:
rebuilt_url += f"?{query}"
return rebuilt_url.lower() # Use lowercase for comparison
def _get_domain(self, url: str) -> Optional[str]:
"""Extracts the network location (domain) from a URL."""
try:
return urlparse(url).netloc.lower()
except Exception:
return None
def _is_valid_url(self, url: str) -> bool:
"""Basic check if a URL seems valid for crawling."""
try:
parsed = urlparse(url)
# Must have scheme (http/https) and netloc (domain)
return all([parsed.scheme in ['http', 'https'], parsed.netloc])
except Exception:
return False
def _extract_links(self, current_url: str) -> Set[str]:
"""Extracts and normalizes unique, valid links from the current page."""
if not self.browser_controller or not self.browser_controller.page:
logger.error("Browser not available for link extraction.")
return set()
extracted_links = set()
try:
# Use Playwright's locator to find all anchor tags
links = self.browser_controller.page.locator('a[href]').all()
logger.debug(f"Found {len(links)} potential link elements on {current_url}.")
for link_locator in links:
try:
href = link_locator.get_attribute('href')
if href:
# Resolve relative URLs against the current page's URL
absolute_url = urljoin(current_url, href.strip())
normalized_url = self._normalize_url(absolute_url)
if normalized_url and self._is_valid_url(normalized_url):
# logger.debug(f" Found link: {href} -> {normalized_url}")
extracted_links.add(normalized_url)
# else: logger.debug(f" Skipping invalid/malformed link: {href} -> {normalized_url}")
except Exception as link_err:
# Log error getting attribute but continue with others
logger.warning(f"Could not get href for a link element on {current_url}: {link_err}")
continue # Skip this link
except Exception as e:
logger.error(f"Error extracting links from {current_url}: {e}", exc_info=True)
logger.info(f"Extracted {len(extracted_links)} unique, valid, normalized links from {current_url}.")
return extracted_links
def _get_test_step_suggestions(self,
page_url: str,
dom_context_str: Optional[str],
screenshot_bytes: Optional[bytes]
) -> List[str]:
"""Asks the LLM to suggest specific test steps based on page URL, DOM, and screenshot."""
logger.info(f"Requesting LLM suggestions for page: {page_url} (using DOM/Screenshot context)")
prompt = f"""
You are an AI Test Analyst identifying valuable test scenarios by suggesting specific test steps.
The crawler is currently on the web page:
URL: {page_url}
Analyze the following page context:
1. **URL & Page Purpose:** Infer the primary purpose of this page (e.g., Login, Blog Post, Product Details, Form Submission, Search Results, Homepage).
2. **Visible DOM Elements:** Review the HTML snippet of visible elements. Note forms, primary action buttons (Submit, Add to Cart, Subscribe), key content areas, inputs, etc. Interactive elements are marked `[index]`, static with `(Static)`.
3. **Screenshot:** Analyze the visual layout, focusing on interactive elements and prominent information relevant to the page's purpose.
**Visible DOM Context:**
```html
{dom_context_str if dom_context_str else "DOM context not available."}
```
{f"**Screenshot Analysis:** Please analyze the attached screenshot for layout, visible text, forms, and key interactive elements." if screenshot_bytes else "**Note:** No screenshot provided."}
**Your Task:**
Based on the inferred purpose and the page context (URL, DOM, Screenshot), suggest **one or two short sequences (totaling 4-7 steps)** of specific, actionable test steps representing **meaningful user interactions or verifications** related to the page's **core functionality**.
**Step Description Requirements:**
* Each step should be a single, clear instruction (e.g., "Click 'Login' button", "Type 'user@example.com' into 'Email' field", "Verify 'Welcome Back!' message is displayed").
* Describe target elements clearly using visual labels, placeholders, or roles (e.g., 'Username field', 'Add to Cart button', 'Subscribe to newsletter checkbox'). **Do NOT include CSS selectors or indices `[index]`**.
* **Prioritize sequences:** Group related actions together logically (e.g., fill form fields -> click submit; select options -> add to cart).
* **Focus on core function:** Test the main reason the page exists (logging in, submitting data, viewing specific content details, adding an item, completing a search, signing up, etc.).
* **Include Verifications:** Crucially, add steps to verify expected outcomes after actions (e.g., "Verify success message 'Item Added' appears", "Verify error message 'Password required' is shown", "Verify user is redirected to dashboard page", "Verify shopping cart count increases").
* **AVOID:** Simply listing navigation links (header, footer, sidebar) unless they are part of a specific task *initiated* on this page (like password recovery). Avoid generic actions ("Click image", "Click text") without clear purpose or verification.
**Examples of GOOD Step Sequences:**
* Login Page: `["Type 'testuser' into Username field", "Type 'wrongpass' into Password field", "Click Login button", "Verify 'Invalid credentials' error message is visible"]`
* Product Page: `["Select 'Red' from Color dropdown", "Click 'Add to Cart' button", "Verify cart icon shows '1 item'", "Navigate to the shopping cart page"]`
* Blog Page (if comments enabled): `["Scroll down to the comments section", "Type 'Great post!' into the comment input box", "Click the 'Submit Comment' button", "Verify 'Comment submitted successfully' message appears"]`
* Newsletter Signup Form: `["Enter 'John Doe' into the Full Name field", "Enter 'john.doe@example.com' into the Email field", "Click the 'Subscribe' button", "Verify confirmation text 'Thanks for subscribing!' is displayed"]`
**Examples of BAD/LOW-VALUE Steps (to avoid):**
* `["Click Home link", "Click About Us link", "Click Contact link"]` (Just navigation, low value unless testing navigation itself specifically)
* `["Click the first image", "Click the second paragraph"]` (No clear purpose or verification)
* `["Type text into search bar"]` (Incomplete - what text? what next? add submit/verify)
**Output Requirements:**
- Provide a JSON object matching the required schema (`SuggestedTestStepsSchema`).
- The `suggested_test_steps` list should contain 4-7 specific steps, ideally forming 1-2 meaningful sequences.
- Provide brief `reasoning` explaining *why* these steps test the core function.
Respond ONLY with the JSON object matching the schema.
"""
# Call generate_json, passing image_bytes if available
response_obj = self.llm_client.generate_json(
SuggestedTestStepsSchema, # Use the updated schema class
prompt,
image_bytes=screenshot_bytes # Pass the image bytes here
)
if isinstance(response_obj, SuggestedTestStepsSchema):
logger.debug(f"LLM suggested steps for {page_url}: {response_obj.suggested_test_steps} (Reason: {response_obj.reasoning})")
# Validate the response list
if response_obj.suggested_test_steps and isinstance(response_obj.suggested_test_steps, list):
valid_steps = [step for step in response_obj.suggested_test_steps if isinstance(step, str) and step.strip()]
if len(valid_steps) != len(response_obj.suggested_test_steps):
logger.warning(f"LLM response for {page_url} contained invalid step entries. Using only valid ones.")
return valid_steps
else:
logger.warning(f"LLM did not return a valid list of steps for {page_url}.")
return []
elif isinstance(response_obj, str): # Handle error string
logger.error(f"LLM suggestion failed for {page_url}: {response_obj}")
return []
else: # Handle unexpected type
logger.error(f"Unexpected response type from LLM for {page_url}: {type(response_obj)}")
return []
def crawl_and_suggest(self, start_url: str, max_pages: int = 10) -> Dict[str, Any]:
"""
Starts the crawling process from the given URL.
Args:
start_url: The initial URL to start crawling from.
max_pages: The maximum number of unique pages to visit and get suggestions for.
Returns:
A dictionary containing the results:
{
"success": bool,
"message": str,
"start_url": str,
"base_domain": str,
"pages_visited": int,
"discovered_steps": Dict[str, List[str]] # {url: [flow1, flow2,...]}
}
"""
logger.info(f"Starting crawl from '{start_url}', max pages: {max_pages}")
crawl_result = {
"success": False,
"message": "Crawl initiated.",
"start_url": start_url,
"base_domain": None,
"pages_visited": 0,
"discovered_steps": {}
}
# --- Initialization ---
self.queue = []
self.visited_urls = set()
self.discovered_steps = {}
normalized_start_url = self._normalize_url(start_url)
if not normalized_start_url or not self._is_valid_url(normalized_start_url):
crawl_result["message"] = f"Invalid start URL provided: {start_url}"
logger.error(crawl_result["message"])
return crawl_result
self.base_domain = self._get_domain(normalized_start_url)
if not self.base_domain:
crawl_result["message"] = f"Could not extract domain from start URL: {start_url}"
logger.error(crawl_result["message"])
return crawl_result
crawl_result["base_domain"] = self.base_domain
self.queue.append(normalized_start_url)
logger.info(f"Base domain set to: {self.base_domain}")
try:
# --- Setup Browser ---
logger.info("Starting browser for crawler...")
self.browser_controller = BrowserController(headless=self.headless)
self.browser_controller.start()
if not self.browser_controller.page:
raise RuntimeError("Failed to initialize browser page for crawler.")
# --- Crawling Loop ---
while self.queue and len(self.visited_urls) < max_pages:
current_url = self.queue.pop(0)
# Skip if already visited
if current_url in self.visited_urls:
logger.debug(f"Skipping already visited URL: {current_url}")
continue
# Check if it belongs to the target domain
current_domain = self._get_domain(current_url)
if current_domain != self.base_domain:
logger.debug(f"Skipping URL outside base domain ({self.base_domain}): {current_url}")
continue
logger.info(f"Visiting ({len(self.visited_urls) + 1}/{max_pages}): {current_url}")
self.visited_urls.add(current_url)
crawl_result["pages_visited"] += 1
dom_context_str: Optional[str] = None
screenshot_bytes: Optional[bytes] = None
# Navigate
try:
self.browser_controller.goto(current_url)
# Optional: Add wait for load state if needed, goto has basic wait
self.browser_controller.page.wait_for_load_state('domcontentloaded', timeout=15000)
actual_url_after_nav = self.browser_controller.get_current_url() # Get the URL we actually landed on
# --- >>> ADDED DOMAIN CHECK AFTER NAVIGATION <<< ---
actual_domain = self._get_domain(actual_url_after_nav)
if actual_domain != self.base_domain:
logger.warning(f"Redirected outside base domain! "
f"Initial: {current_url}, Final: {actual_url_after_nav} ({actual_domain}). Skipping further processing for this page.")
# Optional: Add the actual off-domain URL to visited to prevent loops if it links back?
off_domain_normalized = self._normalize_url(actual_url_after_nav)
if off_domain_normalized:
self.visited_urls.add(off_domain_normalized)
time.sleep(self.politeness_delay) # Still add delay
continue # Skip context gathering, link extraction, suggestions for this off-domain page
# --- Gather Context (DOM + Screenshot) ---
try:
logger.debug(f"Gathering DOM state for {current_url}...")
dom_state = self.browser_controller.get_structured_dom()
if dom_state and dom_state.element_tree:
dom_context_str, _ = dom_state.element_tree.generate_llm_context_string(context_purpose='verification') # Use verification context (more static elements)
logger.debug(f"DOM context string generated (length: {len(dom_context_str)}).")
else:
logger.warning(f"Could not get structured DOM for {current_url}.")
dom_context_str = "Error retrieving DOM structure."
logger.debug(f"Taking screenshot for {current_url}...")
screenshot_bytes = self.browser_controller.take_screenshot()
if not screenshot_bytes:
logger.warning(f"Failed to take screenshot for {current_url}.")
except Exception as context_err:
logger.error(f"Failed to gather context (DOM/Screenshot) for {current_url}: {context_err}")
dom_context_str = f"Error gathering context: {context_err}"
screenshot_bytes = None # Ensure screenshot is None if context gathering failed
except Exception as nav_e:
logger.warning(f"Failed to navigate to {current_url}: {nav_e}. Skipping this page.")
# Don't add links or suggestions if navigation fails
continue # Skip to next URL in queue
# Extract Links
new_links = self._extract_links(current_url)
for link in new_links:
if link not in self.visited_urls and self._get_domain(link) == self.base_domain:
if link not in self.queue: # Add only if not already queued
self.queue.append(link)
# --- Get LLM Suggestions (using gathered context) ---
suggestions = self._get_test_step_suggestions(
current_url,
dom_context_str,
screenshot_bytes
)
if suggestions:
self.discovered_steps[current_url] = suggestions
# Politeness delay
logger.debug(f"Waiting {self.politeness_delay}s before next page...")
time.sleep(self.politeness_delay)
# --- Loop End ---
crawl_result["success"] = True
if len(self.visited_urls) >= max_pages:
crawl_result["message"] = f"Crawl finished: Reached max pages limit ({max_pages})."
logger.info(crawl_result["message"])
elif not self.queue:
crawl_result["message"] = f"Crawl finished: Explored all reachable pages within domain ({len(self.visited_urls)} visited)."
logger.info(crawl_result["message"])
else: # Should not happen unless error
crawl_result["message"] = "Crawl finished unexpectedly."
crawl_result["discovered_steps"] = self.discovered_steps
except Exception as e:
logger.critical(f"Critical error during crawl process: {e}", exc_info=True)
crawl_result["message"] = f"Crawler failed with error: {e}"
crawl_result["success"] = False
finally:
logger.info("--- Ending Crawl ---")
if self.browser_controller:
self.browser_controller.close()
self.browser_controller = None
logger.info(f"Crawl Summary: Visited {crawl_result['pages_visited']} pages. Found suggestions for {len(crawl_result.get('discovered_steps', {}))} pages.")
return crawl_result
```
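For reference, the crawler is driven the same way `discover_test_flows` in `mcp_server.py` drives it. A hedged standalone sketch (assumes the LLM environment variables are configured and a browser is available):

```python
from src.llm.llm_client import LLMClient
from src.agents.crawler_agent import CrawlerAgent

llm_client = LLMClient(provider='azure')
crawler = CrawlerAgent(llm_client=llm_client, headless=True, politeness_delay_sec=1.0)

results = crawler.crawl_and_suggest("https://example.com", max_pages=5)
if results["success"]:
    for url, steps in results["discovered_steps"].items():
        print(url)
        for step in steps:
            print("  -", step)
else:
    print("Crawl failed:", results["message"])
```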
--------------------------------------------------------------------------------
/src/dom/service.py:
--------------------------------------------------------------------------------
```python
# /src/dom/service.py
import gc
import json
import logging
from dataclasses import dataclass
from importlib import resources # Use importlib.resources
from typing import TYPE_CHECKING, Optional, Tuple, Dict, List
import re
# Use relative imports if within the same package structure
from .views import (
DOMBaseNode,
DOMElementNode,
DOMState,
DOMTextNode,
SelectorMap,
ViewportInfo, # Added ViewportInfo here
CoordinateSet # Added CoordinateSet
)
# Removed utils import assuming time_execution_async is defined elsewhere or removed for brevity
# from ..utils import time_execution_async # Example relative import if utils is one level up
if TYPE_CHECKING:
from patchright.sync_api import Page # Use sync_api for this repo
logger = logging.getLogger(__name__)
# Decorator placeholder if not using utils.time_execution_async
def time_execution_async(label):
def decorator(func):
# In a sync context, this decorator needs adjustment or removal
# For simplicity here, we'll just make it pass through in the sync version
def wrapper(*args, **kwargs):
# logger.debug(f"Executing {label}...") # Basic logging
result = func(*args, **kwargs)
# logger.debug(f"Finished {label}.") # Basic logging
return result
return wrapper
return decorator
class DomService:
def __init__(self, page: 'Page'):
self.page = page
self.xpath_cache = {} # Consider if this cache is still needed/used effectively
# Correctly load JS using importlib.resources relative to this file
try:
# Assuming buildDomTree.js is in the same directory 'dom'
with resources.path(__package__, 'buildDomTree.js') as js_path:
self.js_code = js_path.read_text(encoding='utf-8')
logger.debug("buildDomTree.js loaded successfully.")
except FileNotFoundError:
logger.error("buildDomTree.js not found in the 'dom' package directory!")
raise
except Exception as e:
logger.error(f"Error loading buildDomTree.js: {e}", exc_info=True)
raise
# region - Clickable elements
@time_execution_async('--get_clickable_elements')
def get_clickable_elements(
self,
highlight_elements: bool = True,
focus_element: int = -1,
viewport_expansion: int = 0,
) -> DOMState:
"""Gets interactive elements and DOM structure. Sync version."""
logger.debug(f"Calling _build_dom_tree with highlight={highlight_elements}, focus={focus_element}, expansion={viewport_expansion}")
# In sync context, _build_dom_tree should be sync
element_tree, selector_map = self._build_dom_tree(highlight_elements, focus_element, viewport_expansion)
return DOMState(element_tree=element_tree, selector_map=selector_map)
# Removed get_cross_origin_iframes for brevity, can be added back if needed
# @time_execution_async('--build_dom_tree') # Adjust decorator if needed for sync
def _build_dom_tree(
self,
highlight_elements: bool,
focus_element: int,
viewport_expansion: int,
) -> Tuple[DOMElementNode, SelectorMap]:
"""Builds the DOM tree by executing JS in the browser. Sync version."""
logger.debug("Executing _build_dom_tree...")
if self.page.evaluate('1+1') != 2:
raise ValueError('The page cannot evaluate javascript code properly')
if self.page.url == 'about:blank' or self.page.url == '':
logger.info("Page URL is blank, returning empty DOM structure.")
# short-circuit if the page is a new empty tab for speed
return (
DOMElementNode(
tag_name='body',
xpath='',
attributes={},
children=[],
is_visible=False,
parent=None,
),
{},
)
debug_mode = logger.getEffectiveLevel() <= logging.DEBUG
args = {
'doHighlightElements': highlight_elements,
'focusHighlightIndex': focus_element,
'viewportExpansion': viewport_expansion,
'debugMode': debug_mode,
}
logger.debug(f"Evaluating buildDomTree.js with args: {args}")
try:
# Use evaluate() directly in sync context
eval_page: dict = self.page.evaluate(f"({self.js_code})", args)
except Exception as e:
logger.error(f"Error evaluating buildDomTree.js: {type(e).__name__}: {e}", exc_info=False) # Less verbose logging
logger.debug(f"JS Code Snippet (first 500 chars):\n{self.js_code[:500]}...") # Log JS snippet on error
# Try to get page state for context
try:
page_url = self.page.url
page_title = self.page.title()
logger.error(f"Error occurred on page: URL='{page_url}', Title='{page_title}'")
except Exception as page_state_e:
logger.error(f"Could not get page state after JS error: {page_state_e}")
raise RuntimeError(f"Failed to evaluate DOM building script: {e}") from e # Re-raise a standard error
# Only log performance metrics in debug mode
if debug_mode and 'perfMetrics' in eval_page:
logger.debug(
'DOM Tree Building Performance Metrics for: %s\n%s',
self.page.url,
json.dumps(eval_page['perfMetrics'], indent=2),
)
if 'map' not in eval_page or 'rootId' not in eval_page:
logger.error(f"Invalid structure returned from buildDomTree.js: Missing 'map' or 'rootId'. Response keys: {eval_page.keys()}")
# Log more details if possible
logger.error(f"JS Eval Response Snippet: {str(eval_page)[:1000]}...")
# Return empty structure to prevent downstream errors
return (DOMElementNode(tag_name='body', xpath='', attributes={}, children=[], is_visible=False, parent=None), {})
# raise ValueError("Invalid structure returned from DOM building script.")
# Use sync _construct_dom_tree
return self._construct_dom_tree(eval_page)
# @time_execution_async('--construct_dom_tree') # Adjust decorator if needed for sync
def _construct_dom_tree(
self,
eval_page: dict,
) -> Tuple[DOMElementNode, SelectorMap]:
"""Constructs the Python DOM tree from the JS map. Sync version."""
logger.debug("Constructing Python DOM tree from JS map...")
js_node_map = eval_page['map']
js_root_id = eval_page.get('rootId') # Use .get for safety
if js_root_id is None:
logger.error("JS evaluation result missing 'rootId'. Cannot build tree.")
# Return empty structure
return (DOMElementNode(tag_name='body', xpath='', attributes={}, children=[], is_visible=False, parent=None), {})
selector_map: SelectorMap = {}
node_map: Dict[str, DOMBaseNode] = {} # Use string keys consistently
# Iterate through the JS map provided by the browser script
for id_str, node_data in js_node_map.items():
if not isinstance(node_data, dict):
logger.warning(f"Skipping invalid node data (not a dict) for ID: {id_str}")
continue
node, children_ids_str = self._parse_node(node_data)
if node is None:
continue # Skip nodes that couldn't be parsed
node_map[id_str] = node # Store with string ID
# If the node is an element node with a highlight index, add it to the selector map
if isinstance(node, DOMElementNode) and node.highlight_index is not None:
selector_map[node.highlight_index] = node
# Link children to this node if it's an element node
if isinstance(node, DOMElementNode):
for child_id_str in children_ids_str:
child_node = node_map.get(child_id_str) # Use .get() for safety
if child_node:
# Set the parent reference on the child node
child_node.parent = node
# Add the child node to the current node's children list
node.children.append(child_node)
else:
# This can happen if a child node was invalid or filtered out
logger.debug(f"Child node with ID '{child_id_str}' not found in node_map while processing parent '{id_str}'.")
# Retrieve the root node using the root ID from the evaluation result
root_node = node_map.get(str(js_root_id))
# Clean up large intermediate structures
del node_map
del js_node_map
gc.collect()
# Validate the root node
if root_node is None or not isinstance(root_node, DOMElementNode):
logger.error(f"Failed to find valid root DOMElementNode with ID '{js_root_id}'.")
# Return a default empty body node to avoid crashes
return (DOMElementNode(tag_name='body', xpath='', attributes={}, children=[], is_visible=False, parent=None), selector_map)
logger.debug("Finished constructing Python DOM tree.")
return root_node, selector_map
def _parse_node(
self,
node_data: dict,
) -> Tuple[Optional[DOMBaseNode], List[str]]: # Return string IDs
"""Parses a single node dictionary from JS into a Python DOM object. Sync version."""
if not node_data:
return None, []
node_type = node_data.get('type') # Check if it's explicitly a text node
if node_type == 'TEXT_NODE':
# Handle Text Nodes
text = node_data.get('text', '')
if not text: # Skip empty text nodes early
return None, []
text_node = DOMTextNode(
text=text,
is_visible=node_data.get('isVisible', False), # Use .get for safety
parent=None, # Parent set later during construction
)
return text_node, []
elif 'tagName' in node_data:
# Handle Element Nodes
tag_name = node_data['tagName']
# Process coordinates if they exist (using Pydantic models from view)
page_coords_data = node_data.get('pageCoordinates')
viewport_coords_data = node_data.get('viewportCoordinates')
viewport_info_data = node_data.get('viewportInfo')
page_coordinates = CoordinateSet(**page_coords_data) if page_coords_data else None
viewport_coordinates = CoordinateSet(**viewport_coords_data) if viewport_coords_data else None
viewport_info = ViewportInfo(**viewport_info_data) if viewport_info_data else None
element_node = DOMElementNode(
tag_name=tag_name.lower(), # Ensure lowercase
xpath=node_data.get('xpath', ''),
attributes=node_data.get('attributes', {}),
children=[], # Children added later
is_visible=node_data.get('isVisible', False),
is_interactive=node_data.get('isInteractive', False),
is_top_element=node_data.get('isTopElement', False),
is_in_viewport=node_data.get('isInViewport', False),
highlight_index=node_data.get('highlightIndex'), # Can be None
shadow_root=node_data.get('shadowRoot', False),
parent=None, # Parent set later
# Add coordinate fields
page_coordinates=page_coordinates,
viewport_coordinates=viewport_coordinates,
viewport_info=viewport_info,
# Enhanced CSS selector added later if needed
css_selector=None,
)
# Children IDs are strings from the JS map
children_ids_str = node_data.get('children', [])
# Basic validation
if not isinstance(children_ids_str, list):
logger.warning(f"Invalid children format for node {node_data.get('xpath')}, expected list, got {type(children_ids_str)}. Treating as empty.")
children_ids_str = []
return element_node, [str(cid) for cid in children_ids_str] # Ensure IDs are strings
else:
# Skip nodes that are neither TEXT_NODE nor have a tagName (e.g., comments processed out by JS)
logger.debug(f"Skipping node data without 'type' or 'tagName': {str(node_data)[:100]}...")
return None, []
# Add the helper to generate enhanced CSS selectors (adapted from BrowserContext)
# This could also live in a dedicated selector utility class/module
@staticmethod
def _enhanced_css_selector_for_element(element: DOMElementNode) -> str:
"""
Generates a more robust CSS selector, prioritizing stable attributes.
RECORDER FOCUS: Prioritize ID, data-testid, name, stable classes. Fallback carefully.
"""
if not isinstance(element, DOMElementNode):
return ''
# Escape CSS identifiers (simple version, consider edge cases)
def escape_css(value):
if not value: return ''
# Basic escape for characters that are problematic in unquoted identifiers/strings
# See: https://developer.mozilla.org/en-US/docs/Web/CSS/string#escaping_characters
# This is NOT exhaustive but covers common cases.
return re.sub(r'([!"#$%&\'()*+,./:;<=>?@\[\\\]^`{|}~])', r'\\\1', value)
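# Illustrative examples of the escaping above:
#   escape_css("user:name") == r"user\:name"
#   escape_css("btn.primary") == r"btn\.primary"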
# --- Attribute Priority Order ---
# 1. ID (if reasonably unique-looking)
if 'id' in element.attributes and element.attributes['id']:
element_id = element.attributes['id'].strip()
if element_id and not element_id.isdigit() and ' ' not in element_id and ':' not in element_id:
escaped_id = escape_css(element_id)
selector = f"#{escaped_id}"
# If ID seems generic, add tag name
if len(element_id) < 6 and element.tag_name not in ['div', 'span']: # Prefix short IDs with the tag name (skip the prefix for generic div/span containers)
return f"{element.tag_name}{selector}"
return selector
# 2. Stable Data Attributes
for test_attr in ['data-testid', 'data-test-id', 'data-cy', 'data-qa']:
if test_attr in element.attributes and element.attributes[test_attr]:
val = element.attributes[test_attr].strip()
if val:
escaped_val = escape_css(val)
selector = f"[{test_attr}='{escaped_val}']"
# Add tag name if value seems generic
if len(val) < 5:
return f"{element.tag_name}{selector}"
return selector
# 3. Name Attribute
if 'name' in element.attributes and element.attributes['name']:
name_val = element.attributes['name'].strip()
if name_val:
escaped_name = escape_css(name_val)
selector = f"{element.tag_name}[name='{escaped_name}']"
return selector
# 4. Aria-label
if 'aria-label' in element.attributes and element.attributes['aria-label']:
aria_label = element.attributes['aria-label'].strip()
# Ensure label is reasonably specific (not just whitespace or very short)
if aria_label and len(aria_label) > 2 and len(aria_label) < 80:
escaped_label = escape_css(aria_label)
selector = f"{element.tag_name}[aria-label='{escaped_label}']"
return selector
# 5. Placeholder (for inputs)
if element.tag_name == 'input' and 'placeholder' in element.attributes and element.attributes['placeholder']:
placeholder = element.attributes['placeholder'].strip()
if placeholder:
escaped_placeholder = escape_css(placeholder)
selector = f"input[placeholder='{escaped_placeholder}']"
return selector
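# Illustrative outcomes of the attribute priority above (hypothetical elements):
#   <div id="main-content">          -> "#main-content"
#   <button data-testid="checkout">  -> "[data-testid='checkout']"
#   <input name="email">             -> "input[name='email']"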
# --- Text Content Strategy (Use cautiously) ---
# Get DIRECT, visible text content of the element itself
direct_text = ""
if element.is_visible: # Only consider text if element is visible
texts = []
for child in element.children:
if isinstance(child, DOMTextNode) and child.is_visible:
texts.append(child.text.strip())
direct_text = ' '.join(filter(None, texts)).strip()
# 6. Specific Text Content (if short, unique-looking, and element type is suitable)
suitable_text_tags = {'button', 'a', 'span', 'label', 'legend', 'h1', 'h2', 'h3', 'h4', 'p', 'li', 'td', 'th', 'dt', 'dd'}
if direct_text and element.tag_name in suitable_text_tags and 2 < len(direct_text) < 60: # Avoid overly long or short text
# Basic check for uniqueness (could be improved by checking siblings)
# Check if it looks like dynamic content (e.g., numbers only, dates) - skip if so
if not direct_text.isdigit() and not re.match(r'^\$?[\d,.]+$', direct_text): # Avoid pure numbers/prices
# Use Playwright's text selector engine.
# Note: the :text-is pseudo-class requires a reasonably recent Playwright (~1.15+).
# :has-text is generally safer as it also matches text in descendants,
# but here we specifically want the *direct* text match, so combine tag and text for specificity.
# Caveat: text containing single quotes would need extra escaping before being embedded here.
selector = f"{element.tag_name}:text-is('{direct_text}')"
# Alternative: :text() - might be less strict about whitespace
# selector = f"{element.tag_name}:text('{direct_text}')"
# Let's try to validate this selector immediately if possible (costly)
# For now, return it optimistically.
return selector
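# Illustrative: a visible <a> whose direct text is "Forgot password?" would yield
# "a:text-is('Forgot password?')" under the branch above.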
# --- Fallbacks (Structure and Class) ---
base_selector = element.tag_name
stable_classes_used = []
# 7. Stable Class Names (Filter more strictly)
if 'class' in element.attributes and element.attributes['class']:
classes = element.attributes['class'].strip().split()
stable_classes = [
c for c in classes
if c and not c.isdigit() and
not re.search(r'\d', c) and # No digits at all
not re.match(r'.*(--|__|is-|has-|js-|active|selected|disabled|hidden).*', c, re.IGNORECASE) and # Avoid common states/modifiers/js
not re.match(r'^[a-zA-Z]{1,2}$', c) and # Avoid 1-2 letter classes (often layout helpers)
len(c) > 2 and len(c) < 30 # Reasonable length
]
if stable_classes:
stable_classes.sort()
stable_classes_used = stable_classes # Store for nth-of-type check
base_selector += '.' + '.'.join(escape_css(c) for c in stable_classes)
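# Illustrative: class="btn col-3 js-submit active" keeps only "btn" under the filters
# above, giving a base selector like "button.btn".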
# --- Ancestor Context (Find nearest stable ancestor) ---
# Try to find a parent with ID or data-testid to anchor the selector
stable_ancestor_selector = None
current = element.parent
depth = 0
max_depth = 4 # How far up to look for an anchor
while current and depth < max_depth:
ancestor_selector_part = None
if 'id' in current.attributes and current.attributes['id']:
ancestor_id = current.attributes['id'].strip()
if ancestor_id and not ancestor_id.isdigit() and ' ' not in ancestor_id:
ancestor_selector_part = f"#{escape_css(ancestor_id)}"
if not ancestor_selector_part: # Fall back to test-id attributes when no stable ID was found (an unstable ID should not block this check)
for test_attr in ['data-testid', 'data-test-id']:
if test_attr in current.attributes and current.attributes[test_attr]:
val = current.attributes[test_attr].strip()
if val:
ancestor_selector_part = f"[{test_attr}='{escape_css(val)}']"
break # Found one
# If we found a stable part for the ancestor, use it
if ancestor_selector_part:
stable_ancestor_selector = ancestor_selector_part
break # Stop searching up
current = current.parent
depth += 1
# Combine ancestor and base selector if ancestor found
final_selector = f"{stable_ancestor_selector} >> {base_selector}" if stable_ancestor_selector else base_selector
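# Illustrative: ">>" is Playwright's selector-chaining syntax, so a result here might
# look like "#login-form >> input.form-field" (hypothetical values).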
# 8. Add :nth-of-type ONLY if multiple siblings match the current selector AND no unique attribute/text was found
# This check becomes more complex with the ancestor path. We simplify here.
# Only add nth-of-type if we didn't find a unique ID/testid/name/text for the element itself.
needs_disambiguation = (stable_ancestor_selector is None) and \
(base_selector == element.tag_name or base_selector.startswith(element.tag_name + '.')) # Only tag or tag+class
if needs_disambiguation and element.parent:
try:
# Find siblings matching the base selector part (tag + potentially classes)
matching_siblings = []
for sib in element.parent.children:
if isinstance(sib, DOMElementNode) and sib.tag_name == element.tag_name:
# Check classes if they were used in the base selector
if stable_classes_used:
if DomService._check_classes_match(sib, stable_classes_used):
matching_siblings.append(sib)
else: # No classes used, just match tag
matching_siblings.append(sib)
if len(matching_siblings) > 1:
try:
index = matching_siblings.index(element) + 1
final_selector += f':nth-of-type({index})'
except ValueError:
logger.warning(f"Element not found in its own filtered sibling list for nth-of-type. Selector: {final_selector}")
except Exception as e:
logger.warning(f"Error during nth-of-type calculation: {e}. Selector: {final_selector}")
# 9. FINAL FALLBACK: Use original XPath if selector is still not specific
if final_selector == element.tag_name and element.xpath:
logger.warning(f"Selector for {element.tag_name} is just the tag. Falling back to XPath: {element.xpath}")
# Returning a raw XPath might cause issues if the executor expects CSS,
# so return it with Playwright's explicit "xpath=" prefix instead.
return f"xpath={element.xpath}"
return final_selector
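# Rough summary of the fallback chain implemented above (illustrative):
#   id -> data-test* attributes -> name -> aria-label -> placeholder -> direct text
#   -> stable classes (optionally anchored to a stable ancestor and :nth-of-type)
#   -> "xpath=<original xpath>" as the last resort.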
@staticmethod
def _check_classes_match(element: DOMElementNode, required_classes: List[str]) -> bool:
"""Helper to check if an element has all the required classes."""
if 'class' not in element.attributes or not element.attributes['class']:
return False
element_classes = set(element.attributes['class'].strip().split())
return all(req_class in element_classes for req_class in required_classes)
```
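A minimal usage sketch of the selector heuristic above, assuming `DomService` and `DOMElementNode` are importable from `src/dom/service.py` and `src/dom/views.py` as laid out in this repository (the element values are hypothetical):

```python
from src.dom.service import DomService
from src.dom.views import DOMElementNode

# Hypothetical element, built with the same constructor fields the service itself uses above.
button = DOMElementNode(
    tag_name='button',
    xpath='/html/body/main/button[1]',
    attributes={'data-testid': 'checkout'},
    children=[],
    is_visible=True,
    parent=None,
)

# Expected (illustrative): "[data-testid='checkout']", per the data-attribute priority.
print(DomService._enhanced_css_selector_for_element(button))
```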