This is page 2 of 3. Use http://codebase.md/tokidoo/crawl4ai-rag-mcp?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .gitattributes
├── .gitignore
├── Caddyfile
├── crawled_pages.sql
├── docker-compose.yml
├── Dockerfile
├── knowledge_graphs
│ ├── ai_hallucination_detector.py
│ ├── ai_script_analyzer.py
│ ├── hallucination_reporter.py
│ ├── knowledge_graph_validator.py
│ ├── parse_repo_into_neo4j.py
│ ├── query_knowledge_graph.py
│ └── test_script.py
├── LICENSE
├── pyproject.toml
├── README.md
├── searxng
│ ├── limiter.toml
│ └── settings.yml
├── src
│ ├── crawl4ai_mcp.py
│ └── utils.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/knowledge_graphs/parse_repo_into_neo4j.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Direct Neo4j GitHub Code Repository Extractor
3 |
4 | Creates nodes and relationships directly in Neo4j without Graphiti:
5 | - File nodes
6 | - Class nodes
7 | - Method nodes
8 | - Function nodes
9 | - Import relationships
10 |
11 | Bypasses all LLM processing for maximum speed.
12 | """
13 |
14 | import asyncio
15 | import logging
16 | import os
17 | import subprocess
18 | import shutil
19 | from datetime import datetime, timezone
20 | from pathlib import Path
21 | from typing import List, Optional, Dict, Any, Set
22 | import ast
23 |
24 | from dotenv import load_dotenv
25 | from neo4j import AsyncGraphDatabase
26 |
27 | # Configure logging
28 | logging.basicConfig(
29 | level=logging.INFO,
30 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
31 | datefmt='%Y-%m-%d %H:%M:%S',
32 | )
33 | logger = logging.getLogger(__name__)
34 |
35 |
class Neo4jCodeAnalyzer:
    """Analyzes Python source files for direct Neo4j insertion.

    Walks each file's AST and extracts classes (public methods and annotated
    attributes), top-level functions, and project-internal imports as plain
    dicts that the extractor can turn into graph nodes and relationships.
    """

    def __init__(self):
        # External modules to ignore when classifying imports as internal
        self.external_modules = {
            # Python standard library
            'os', 'sys', 'json', 'logging', 'datetime', 'pathlib', 'typing', 'collections',
            'asyncio', 'subprocess', 'ast', 're', 'string', 'urllib', 'http', 'email',
            'time', 'uuid', 'hashlib', 'base64', 'itertools', 'functools', 'operator',
            'contextlib', 'copy', 'pickle', 'tempfile', 'shutil', 'glob', 'fnmatch',
            'io', 'codecs', 'locale', 'platform', 'socket', 'ssl', 'threading', 'queue',
            'multiprocessing', 'concurrent', 'warnings', 'traceback', 'inspect',
            'importlib', 'pkgutil', 'types', 'weakref', 'gc', 'dataclasses', 'enum',
            'abc', 'numbers', 'decimal', 'fractions', 'math', 'cmath', 'random', 'statistics',

            # Common third-party libraries
            'requests', 'urllib3', 'httpx', 'aiohttp', 'flask', 'django', 'fastapi',
            'pydantic', 'sqlalchemy', 'alembic', 'psycopg2', 'pymongo', 'redis',
            'celery', 'pytest', 'unittest', 'mock', 'faker', 'factory', 'hypothesis',
            'numpy', 'pandas', 'matplotlib', 'seaborn', 'scipy', 'sklearn', 'torch',
            'tensorflow', 'keras', 'opencv', 'pillow', 'boto3', 'botocore', 'azure',
            'google', 'openai', 'anthropic', 'langchain', 'transformers', 'huggingface_hub',
            'click', 'typer', 'rich', 'colorama', 'tqdm', 'python-dotenv', 'pyyaml',
            'toml', 'configargparse', 'marshmallow', 'attrs', 'dataclasses-json',
            'jsonschema', 'cerberus', 'voluptuous', 'schema', 'jinja2', 'mako',
            'cryptography', 'bcrypt', 'passlib', 'jwt', 'authlib', 'oauthlib'
        }

    @staticmethod
    def _strip_py_suffix(dotted: str) -> str:
        """Remove a trailing '.py' from a dotted module path.

        str.replace('.py', '') would also corrupt components that merely
        contain '.py' (e.g. 'app.pyhelpers.mod.py' -> 'apphelpers.mod'),
        so only the suffix is removed.
        """
        return dotted[:-3] if dotted.endswith('.py') else dotted

    @staticmethod
    def _format_params_detailed(params: List[Dict[str, Any]]) -> List[str]:
        """Render parameter dicts as 'name:type[=default]' strings.

        Non-positional parameters are prefixed with their kind,
        e.g. '[keyword_only] timeout:int=30'.
        """
        detailed = []
        for p in params:
            param_str = f"{p['name']}:{p['type']}"
            if p['optional'] and p['default'] is not None:
                param_str += f"={p['default']}"
            elif p['optional']:
                param_str += "=None"
            if p['kind'] != 'positional':
                param_str = f"[{p['kind']}] {param_str}"
            detailed.append(param_str)
        return detailed

    def analyze_python_file(self, file_path: Path, repo_root: Path, project_modules: Set[str]) -> Optional[Dict[str, Any]]:
        """Extract structure for direct Neo4j insertion.

        Returns a dict with 'module_name', 'file_path', 'classes',
        'functions', 'imports' and 'line_count', or None when the file
        cannot be read or parsed.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()

            tree = ast.parse(content)
            relative_path = str(file_path.relative_to(repo_root))
            module_name = self._get_importable_module_name(file_path, repo_root, relative_path)

            # Direct children of class bodies (methods, attributes). Computed
            # once so the "is this function a method?" check below is O(1)
            # per node instead of re-walking the whole tree for every function.
            class_body_nodes = {
                child
                for cls_node in ast.walk(tree)
                if isinstance(cls_node, ast.ClassDef)
                for child in cls_node.body
            }

            # Extract structure
            classes = []
            functions = []
            imports = []

            for node in ast.walk(tree):
                if isinstance(node, ast.ClassDef):
                    # Extract class with its public methods and annotated attributes
                    methods = []
                    attributes = []

                    for item in node.body:
                        if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
                            if not item.name.startswith('_'):  # Public methods only
                                # Extract comprehensive parameter info
                                params = self._extract_function_parameters(item)

                                # Get return type annotation
                                return_type = self._get_name(item.returns) if item.returns else 'Any'

                                methods.append({
                                    'name': item.name,
                                    'params': params,  # Full parameter objects
                                    'params_detailed': self._format_params_detailed(params),  # Detailed string format
                                    'return_type': return_type,
                                    'args': [arg.arg for arg in item.args.args if arg.arg != 'self']  # Keep for backwards compatibility
                                })
                        elif isinstance(item, ast.AnnAssign) and isinstance(item.target, ast.Name):
                            # Type annotated attributes (public only)
                            if not item.target.id.startswith('_'):
                                attributes.append({
                                    'name': item.target.id,
                                    'type': self._get_name(item.annotation) if item.annotation else 'Any'
                                })

                    classes.append({
                        'name': node.name,
                        'full_name': f"{module_name}.{node.name}",
                        'methods': methods,
                        'attributes': attributes
                    })

                elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                    # Only public functions that are not direct class members (methods)
                    if node not in class_body_nodes and not node.name.startswith('_'):
                        # Extract comprehensive parameter info
                        params = self._extract_function_parameters(node)

                        # Get return type annotation
                        return_type = self._get_name(node.returns) if node.returns else 'Any'

                        functions.append({
                            'name': node.name,
                            'full_name': f"{module_name}.{node.name}",
                            'params': params,  # Full parameter objects
                            'params_detailed': self._format_params_detailed(params),  # Detailed string format
                            'params_list': [f"{p['name']}:{p['type']}" for p in params],  # Simple string format for backwards compatibility
                            'return_type': return_type,
                            'args': [arg.arg for arg in node.args.args]  # Keep for backwards compatibility
                        })

                elif isinstance(node, (ast.Import, ast.ImportFrom)):
                    # Track internal imports only
                    if isinstance(node, ast.Import):
                        for alias in node.names:
                            if self._is_likely_internal(alias.name, project_modules):
                                imports.append(alias.name)
                    elif isinstance(node, ast.ImportFrom) and node.module:
                        if (node.module.startswith('.') or self._is_likely_internal(node.module, project_modules)):
                            imports.append(node.module)

            return {
                'module_name': module_name,
                'file_path': relative_path,
                'classes': classes,
                'functions': functions,
                'imports': list(set(imports)),  # Remove duplicates
                'line_count': len(content.splitlines())
            }

        except Exception as e:
            logger.warning(f"Could not analyze {file_path}: {e}")
            return None

    def _is_likely_internal(self, import_name: str, project_modules: Set[str]) -> bool:
        """Heuristically decide whether an import belongs to the project.

        Relative imports are internal; known stdlib/third-party modules are
        external; a match against a detected project module is internal;
        otherwise fall back to a permissive guess.
        """
        if not import_name:
            return False

        # Relative imports are definitely internal
        if import_name.startswith('.'):
            return True

        # Check if it's a known external module
        base_module = import_name.split('.')[0]
        if base_module in self.external_modules:
            return False

        # Check if it matches any project module
        for project_module in project_modules:
            if import_name.startswith(project_module):
                return True

        # If it's not obviously external, consider it internal
        if (not any(ext in base_module.lower() for ext in ['test', 'mock', 'fake']) and
            not base_module.startswith('_') and
            len(base_module) > 2):
            return True

        return False

    def _get_importable_module_name(self, file_path: Path, repo_root: Path, relative_path: str) -> str:
        """Determine the actual importable module name for a Python file.

        Prefers the outermost package (first directory level containing an
        __init__.py); otherwise skips common non-package wrapper directories
        such as 'src' or 'lib'.
        """
        # Start with the default: convert file path to a dotted module path
        default_module = self._strip_py_suffix(relative_path.replace('/', '.').replace('\\', '.'))

        path_parts = Path(relative_path).parts

        # Check each directory level for __init__.py to find package boundaries
        package_roots = []
        current_path = repo_root
        for i, part in enumerate(path_parts[:-1]):  # Exclude the .py file itself
            current_path = current_path / part
            if (current_path / '__init__.py').exists():
                # This is a package directory, mark it as a potential root
                package_roots.append(i)

        if package_roots:
            # Use the first (outermost) package as the root
            module_parts = path_parts[package_roots[0]:]
            return self._strip_py_suffix('.'.join(module_parts))

        # Fallback: look for common Python project structures.
        # Skip common non-package wrapper directories.
        skip_dirs = {'src', 'lib', 'source', 'python', 'pkg', 'packages'}

        filtered_parts = []
        for part in path_parts:
            if part.lower() not in skip_dirs or filtered_parts:  # Once we start including, include everything
                filtered_parts.append(part)

        if filtered_parts:
            return self._strip_py_suffix('.'.join(filtered_parts))

        # Final fallback: use the default
        return default_module

    def _extract_function_parameters(self, func_node):
        """Build a descriptor dict for every parameter of a function def.

        'self' is excluded. Each dict carries name, type, kind ('positional',
        'var_positional', 'keyword_only', 'var_keyword'), an optional flag
        and the default value rendered as text (or None).
        """
        params = []

        # Regular positional arguments
        for i, arg in enumerate(func_node.args.args):
            if arg.arg == 'self':
                continue

            param_info = {
                'name': arg.arg,
                'type': self._get_name(arg.annotation) if arg.annotation else 'Any',
                'kind': 'positional',
                'optional': False,
                'default': None
            }

            # Defaults align with the tail of args.args
            defaults_start = len(func_node.args.args) - len(func_node.args.defaults)
            if i >= defaults_start:
                default_idx = i - defaults_start
                if default_idx < len(func_node.args.defaults):
                    param_info['optional'] = True
                    param_info['default'] = self._get_default_value(func_node.args.defaults[default_idx])

            params.append(param_info)

        # *args parameter
        if func_node.args.vararg:
            params.append({
                'name': f"*{func_node.args.vararg.arg}",
                'type': self._get_name(func_node.args.vararg.annotation) if func_node.args.vararg.annotation else 'Any',
                'kind': 'var_positional',
                'optional': True,
                'default': None
            })

        # Keyword-only arguments (after *)
        for i, arg in enumerate(func_node.args.kwonlyargs):
            param_info = {
                'name': arg.arg,
                'type': self._get_name(arg.annotation) if arg.annotation else 'Any',
                'kind': 'keyword_only',
                'optional': True,  # Optional unless shown below to lack a default
                'default': None
            }

            # kw_defaults is aligned with kwonlyargs; None means "no default"
            if i < len(func_node.args.kw_defaults) and func_node.args.kw_defaults[i] is not None:
                param_info['default'] = self._get_default_value(func_node.args.kw_defaults[i])
            else:
                param_info['optional'] = False  # No default = required kwonly arg

            params.append(param_info)

        # **kwargs parameter
        if func_node.args.kwarg:
            params.append({
                'name': f"**{func_node.args.kwarg.arg}",
                'type': self._get_name(func_node.args.kwarg.annotation) if func_node.args.kwarg.annotation else 'Dict[str, Any]',
                'kind': 'var_keyword',
                'optional': True,
                'default': None
            })

        return params

    def _get_default_value(self, default_node):
        """Render a default-value AST node as short text; '...' when unknown."""
        try:
            if isinstance(default_node, ast.Constant):
                return repr(default_node.value)
            elif isinstance(default_node, ast.Name):
                return default_node.id
            elif isinstance(default_node, ast.Attribute):
                return self._get_name(default_node)
            elif isinstance(default_node, ast.List):
                return "[]"
            elif isinstance(default_node, ast.Dict):
                return "{}"
            else:
                return "..."
        except Exception:
            return "..."

    def _get_name(self, node):
        """Extract a readable name from an annotation AST node.

        Handles dotted names and subscripted generics; falls back to 'Any'
        for anything it cannot render safely.
        """
        if node is None:
            return "Any"

        try:
            if isinstance(node, ast.Name):
                return node.id
            elif isinstance(node, ast.Attribute):
                return f"{self._get_name(node.value)}.{node.attr}"
            elif isinstance(node, ast.Subscript):
                # Handle List[Type], Dict[K,V], etc.
                base = self._get_name(node.value)
                if hasattr(node, 'slice'):
                    if isinstance(node.slice, ast.Name):
                        return f"{base}[{node.slice.id}]"
                    elif isinstance(node.slice, ast.Tuple):
                        elts = [self._get_name(elt) for elt in node.slice.elts]
                        return f"{base}[{', '.join(elts)}]"
                    elif isinstance(node.slice, ast.Constant):
                        return f"{base}[{repr(node.slice.value)}]"
                    elif isinstance(node.slice, (ast.Attribute, ast.Subscript)):
                        return f"{base}[{self._get_name(node.slice)}]"
                    else:
                        # Try to get the name of the slice, fallback to Any if it fails
                        try:
                            slice_name = self._get_name(node.slice)
                            return f"{base}[{slice_name}]"
                        except Exception:
                            return f"{base}[Any]"
                return base
            elif isinstance(node, ast.Constant):
                return str(node.value)
            elif isinstance(node, ast.Str):  # Python < 3.8
                return f'"{node.s}"'
            elif isinstance(node, ast.Tuple):
                elts = [self._get_name(elt) for elt in node.elts]
                return f"({', '.join(elts)})"
            elif isinstance(node, ast.List):
                elts = [self._get_name(elt) for elt in node.elts]
                return f"[{', '.join(elts)}]"
            else:
                # Fallback for complex types - return a simple string representation
                return "Any"
        except Exception:
            # If anything goes wrong, return a safe default
            return "Any"
394 |
395 |
class DirectNeo4jExtractor:
    """Creates nodes and relationships directly in Neo4j.

    Clones a repository, analyzes its Python files with Neo4jCodeAnalyzer,
    and writes File/Class/Method/Attribute/Function nodes plus their
    relationships straight into Neo4j (no LLM processing).
    """

    def __init__(self, neo4j_uri: str, neo4j_user: str, neo4j_password: str):
        self.neo4j_uri = neo4j_uri
        self.neo4j_user = neo4j_user
        self.neo4j_password = neo4j_password
        self.driver = None  # Created lazily by initialize()
        self.analyzer = Neo4jCodeAnalyzer()

    @staticmethod
    def _repo_name_from_url(repo_url: str) -> str:
        """Derive a repository name from its clone URL.

        Strips a trailing '.git' suffix only at the end (str.replace('.git', '')
        would also mangle names that merely contain '.git') and tolerates a
        trailing slash on the URL.
        """
        name = repo_url.rstrip('/').split('/')[-1]
        return name[:-4] if name.endswith('.git') else name

    @staticmethod
    def _remove_tree(path: str):
        """Best-effort recursive removal of a directory tree.

        Read-only entries (common under .git on Windows) are chmod'ed
        writable and retried; entries still in use are skipped with a
        warning instead of aborting the whole run.
        """
        def handle_remove_readonly(func, path, exc):
            try:
                if os.path.exists(path):
                    os.chmod(path, 0o777)
                    func(path)
            except PermissionError:
                logger.warning(f"Could not remove {path} - file in use, skipping")

        shutil.rmtree(path, onerror=handle_remove_readonly)

    async def initialize(self):
        """Initialize the Neo4j driver and ensure constraints/indexes exist."""
        logger.info("Initializing Neo4j connection...")
        self.driver = AsyncGraphDatabase.driver(
            self.neo4j_uri,
            auth=(self.neo4j_user, self.neo4j_password)
        )

        logger.info("Creating constraints and indexes...")
        async with self.driver.session() as session:
            # Uniqueness only for files and classes. Methods, attributes and
            # functions can legitimately share full names across classes, so
            # they are deduplicated via MERGE on synthetic ids instead of
            # uniqueness constraints.
            await session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (f:File) REQUIRE f.path IS UNIQUE")
            await session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (c:Class) REQUIRE c.full_name IS UNIQUE")

            # Create indexes for performance
            await session.run("CREATE INDEX IF NOT EXISTS FOR (f:File) ON (f.name)")
            await session.run("CREATE INDEX IF NOT EXISTS FOR (c:Class) ON (c.name)")
            await session.run("CREATE INDEX IF NOT EXISTS FOR (m:Method) ON (m.name)")

        logger.info("Neo4j initialized successfully")

    async def clear_repository_data(self, repo_name: str):
        """Clear all graph data for a specific repository.

        Deletes leaves first (methods/attributes), then functions, classes,
        files and finally the repository node, to avoid dangling references.
        """
        logger.info(f"Clearing existing data for repository: {repo_name}")
        async with self.driver.session() as session:
            # 1. Delete methods and attributes (they depend on classes)
            await session.run("""
                MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)-[:HAS_METHOD]->(m:Method)
                DETACH DELETE m
            """, repo_name=repo_name)

            await session.run("""
                MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)-[:HAS_ATTRIBUTE]->(a:Attribute)
                DETACH DELETE a
            """, repo_name=repo_name)

            # 2. Delete functions (they depend on files)
            await session.run("""
                MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(func:Function)
                DETACH DELETE func
            """, repo_name=repo_name)

            # 3. Delete classes (they depend on files)
            await session.run("""
                MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)
                DETACH DELETE c
            """, repo_name=repo_name)

            # 4. Delete files (they depend on repository)
            await session.run("""
                MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)
                DETACH DELETE f
            """, repo_name=repo_name)

            # 5. Finally delete the repository
            await session.run("""
                MATCH (r:Repository {name: $repo_name})
                DETACH DELETE r
            """, repo_name=repo_name)

        logger.info(f"Cleared data for repository: {repo_name}")

    async def close(self):
        """Close the Neo4j connection if it was opened."""
        if self.driver:
            await self.driver.close()

    def clone_repo(self, repo_url: str, target_dir: str) -> str:
        """Shallow-clone repo_url into target_dir, replacing any leftover copy."""
        logger.info(f"Cloning repository to: {target_dir}")
        if os.path.exists(target_dir):
            logger.info(f"Removing existing directory: {target_dir}")
            try:
                self._remove_tree(target_dir)
            except Exception as e:
                logger.warning(f"Could not fully remove {target_dir}: {e}. Proceeding anyway...")

        logger.info(f"Running git clone from {repo_url}")
        subprocess.run(['git', 'clone', '--depth', '1', repo_url, target_dir], check=True)
        logger.info("Repository cloned successfully")
        return target_dir

    def get_python_files(self, repo_path: str) -> List[Path]:
        """Get Python files, focusing on main source directories.

        Skips test/build/doc directories, test_* files, setup.py/conftest.py,
        and anything over 500 KB.
        """
        python_files = []
        exclude_dirs = {
            'tests', 'test', '__pycache__', '.git', 'venv', 'env',
            'node_modules', 'build', 'dist', '.pytest_cache', 'docs',
            'examples', 'example', 'demo', 'benchmark'
        }

        for root, dirs, files in os.walk(repo_path):
            # Prune excluded and hidden directories in place so os.walk skips them
            dirs[:] = [d for d in dirs if d not in exclude_dirs and not d.startswith('.')]

            for file in files:
                if file.endswith('.py') and not file.startswith('test_'):
                    file_path = Path(root) / file
                    if (file_path.stat().st_size < 500_000 and
                        file not in ['setup.py', 'conftest.py']):
                        python_files.append(file_path)

        return python_files

    async def analyze_repository(self, repo_url: str, temp_dir: str = None):
        """Clone, analyze and graph a repository.

        Existing graph data for the repository is cleared only after the
        clone succeeds, so a failed clone does not wipe previous results.
        """
        repo_name = self._repo_name_from_url(repo_url)
        logger.info(f"Analyzing repository: {repo_name}")

        # Set default temp_dir to repos folder at script level
        if temp_dir is None:
            script_dir = Path(__file__).parent
            temp_dir = str(script_dir / "repos" / repo_name)

        # Clone and analyze
        repo_path = Path(self.clone_repo(repo_url, temp_dir))

        try:
            # Clear existing data for this repository before re-processing
            await self.clear_repository_data(repo_name)

            logger.info("Getting Python files...")
            python_files = self.get_python_files(str(repo_path))
            logger.info(f"Found {len(python_files)} Python files to analyze")

            # First pass: identify top-level project modules so the analyzer
            # can classify imports as internal vs external.
            logger.info("Identifying project modules...")
            project_modules = set()
            for file_path in python_files:
                relative_path = str(file_path.relative_to(repo_path))
                # Use path components rather than str.replace('.py', ''),
                # which corrupts components that merely contain '.py'
                # (e.g. 'src/pyutils/mod.py' -> module 'srcutils').
                top_part = Path(relative_path).parts[0]
                if top_part.endswith('.py'):
                    top_part = top_part[:-3]
                if top_part and not top_part.startswith('.'):
                    project_modules.add(top_part)

            logger.info(f"Identified project modules: {sorted(project_modules)}")

            # Second pass: analyze files and collect data
            logger.info("Analyzing Python files...")
            modules_data = []
            for i, file_path in enumerate(python_files):
                if i % 20 == 0:
                    logger.info(f"Analyzing file {i+1}/{len(python_files)}: {file_path.name}")

                analysis = self.analyzer.analyze_python_file(file_path, repo_path, project_modules)
                if analysis:
                    modules_data.append(analysis)

            logger.info(f"Found {len(modules_data)} files with content")

            # Create nodes and relationships in Neo4j
            logger.info("Creating nodes and relationships in Neo4j...")
            await self._create_graph(repo_name, modules_data)

            # Print summary
            total_classes = sum(len(mod['classes']) for mod in modules_data)
            total_methods = sum(len(cls['methods']) for mod in modules_data for cls in mod['classes'])
            total_functions = sum(len(mod['functions']) for mod in modules_data)
            total_imports = sum(len(mod['imports']) for mod in modules_data)

            print(f"\\n=== Direct Neo4j Repository Analysis for {repo_name} ===")
            print(f"Files processed: {len(modules_data)}")
            print(f"Classes created: {total_classes}")
            print(f"Methods created: {total_methods}")
            print(f"Functions created: {total_functions}")
            print(f"Import relationships: {total_imports}")

            logger.info(f"Successfully created Neo4j graph for {repo_name}")

        finally:
            if os.path.exists(temp_dir):
                logger.info(f"Cleaning up temporary directory: {temp_dir}")
                try:
                    self._remove_tree(temp_dir)
                    logger.info("Cleanup completed")
                except Exception as e:
                    # Don't fail the whole process due to cleanup issues
                    logger.warning(f"Cleanup failed: {e}. Directory may remain at {temp_dir}")

    async def _create_graph(self, repo_name: str, modules_data: List[Dict]):
        """Create all nodes and relationships in Neo4j for one repository."""

        async with self.driver.session() as session:
            # Create Repository node
            await session.run(
                "CREATE (r:Repository {name: $repo_name, created_at: datetime()})",
                repo_name=repo_name
            )

            nodes_created = 0
            relationships_created = 0

            for i, mod in enumerate(modules_data):
                # 1. Create File node
                await session.run("""
                    CREATE (f:File {
                        name: $name,
                        path: $path,
                        module_name: $module_name,
                        line_count: $line_count,
                        created_at: datetime()
                    })
                """,
                    name=mod['file_path'].split('/')[-1],
                    path=mod['file_path'],
                    module_name=mod['module_name'],
                    line_count=mod['line_count']
                )
                nodes_created += 1

                # 2. Connect File to Repository
                await session.run("""
                    MATCH (r:Repository {name: $repo_name})
                    MATCH (f:File {path: $file_path})
                    CREATE (r)-[:CONTAINS]->(f)
                """, repo_name=repo_name, file_path=mod['file_path'])
                relationships_created += 1

                # 3. Create Class nodes and relationships
                for cls in mod['classes']:
                    # Create Class node using MERGE to avoid duplicates
                    await session.run("""
                        MERGE (c:Class {full_name: $full_name})
                        ON CREATE SET c.name = $name, c.created_at = datetime()
                    """, name=cls['name'], full_name=cls['full_name'])
                    nodes_created += 1

                    # Connect File to Class
                    await session.run("""
                        MATCH (f:File {path: $file_path})
                        MATCH (c:Class {full_name: $class_full_name})
                        MERGE (f)-[:DEFINES]->(c)
                    """, file_path=mod['file_path'], class_full_name=cls['full_name'])
                    relationships_created += 1

                    # 4. Create Method nodes - use MERGE on a synthetic id to
                    # avoid conflicts between same-named methods of different classes
                    for method in cls['methods']:
                        method_full_name = f"{cls['full_name']}.{method['name']}"
                        method_id = f"{cls['full_name']}::{method['name']}"

                        await session.run("""
                            MERGE (m:Method {method_id: $method_id})
                            ON CREATE SET m.name = $name,
                                         m.full_name = $full_name,
                                         m.args = $args,
                                         m.params_list = $params_list,
                                         m.params_detailed = $params_detailed,
                                         m.return_type = $return_type,
                                         m.created_at = datetime()
                        """,
                            name=method['name'],
                            full_name=method_full_name,
                            method_id=method_id,
                            args=method['args'],
                            params_list=[f"{p['name']}:{p['type']}" for p in method['params']],  # Simple format
                            params_detailed=method.get('params_detailed', []),  # Detailed format
                            return_type=method['return_type']
                        )
                        nodes_created += 1

                        # Connect Class to Method
                        await session.run("""
                            MATCH (c:Class {full_name: $class_full_name})
                            MATCH (m:Method {method_id: $method_id})
                            MERGE (c)-[:HAS_METHOD]->(m)
                        """,
                            class_full_name=cls['full_name'],
                            method_id=method_id
                        )
                        relationships_created += 1

                    # 5. Create Attribute nodes - use MERGE on a synthetic id
                    for attr in cls['attributes']:
                        attr_full_name = f"{cls['full_name']}.{attr['name']}"
                        attr_id = f"{cls['full_name']}::{attr['name']}"
                        await session.run("""
                            MERGE (a:Attribute {attr_id: $attr_id})
                            ON CREATE SET a.name = $name,
                                         a.full_name = $full_name,
                                         a.type = $type,
                                         a.created_at = datetime()
                        """,
                            name=attr['name'],
                            full_name=attr_full_name,
                            attr_id=attr_id,
                            type=attr['type']
                        )
                        nodes_created += 1

                        # Connect Class to Attribute
                        await session.run("""
                            MATCH (c:Class {full_name: $class_full_name})
                            MATCH (a:Attribute {attr_id: $attr_id})
                            MERGE (c)-[:HAS_ATTRIBUTE]->(a)
                        """,
                            class_full_name=cls['full_name'],
                            attr_id=attr_id
                        )
                        relationships_created += 1

                # 6. Create Function nodes (top-level) - use MERGE on a synthetic id
                for func in mod['functions']:
                    func_id = f"{mod['file_path']}::{func['name']}"
                    await session.run("""
                        MERGE (f:Function {func_id: $func_id})
                        ON CREATE SET f.name = $name,
                                     f.full_name = $full_name,
                                     f.args = $args,
                                     f.params_list = $params_list,
                                     f.params_detailed = $params_detailed,
                                     f.return_type = $return_type,
                                     f.created_at = datetime()
                    """,
                        name=func['name'],
                        full_name=func['full_name'],
                        func_id=func_id,
                        args=func['args'],
                        params_list=func.get('params_list', []),  # Simple format for backwards compatibility
                        params_detailed=func.get('params_detailed', []),  # Detailed format
                        return_type=func['return_type']
                    )
                    nodes_created += 1

                    # Connect File to Function
                    await session.run("""
                        MATCH (file:File {path: $file_path})
                        MATCH (func:Function {func_id: $func_id})
                        MERGE (file)-[:DEFINES]->(func)
                    """, file_path=mod['file_path'], func_id=func_id)
                    relationships_created += 1

                # 7. Create Import relationships
                for import_name in mod['imports']:
                    # Try to find the target file by its module name
                    await session.run("""
                        MATCH (source:File {path: $source_path})
                        OPTIONAL MATCH (target:File)
                        WHERE target.module_name = $import_name OR target.module_name STARTS WITH $import_name
                        WITH source, target
                        WHERE target IS NOT NULL
                        MERGE (source)-[:IMPORTS]->(target)
                    """, source_path=mod['file_path'], import_name=import_name)
                    relationships_created += 1

                if (i + 1) % 10 == 0:
                    logger.info(f"Processed {i + 1}/{len(modules_data)} files...")

            logger.info(f"Created {nodes_created} nodes and {relationships_created} relationships")

    async def search_graph(self, query_type: str, **kwargs):
        """Search the Neo4j graph directly.

        Supported query_type values: 'files_importing' (kwarg: target),
        'classes_in_file' (kwarg: file_path), 'methods_of_class'
        (kwarg: class_name). Returns a list of dicts, or None for an
        unknown query_type.
        """
        async with self.driver.session() as session:
            if query_type == "files_importing":
                target = kwargs.get('target')
                result = await session.run("""
                    MATCH (source:File)-[:IMPORTS]->(target:File)
                    WHERE target.module_name CONTAINS $target
                    RETURN source.path as file, target.module_name as imports
                """, target=target)
                return [{"file": record["file"], "imports": record["imports"]} async for record in result]

            elif query_type == "classes_in_file":
                file_path = kwargs.get('file_path')
                result = await session.run("""
                    MATCH (f:File {path: $file_path})-[:DEFINES]->(c:Class)
                    RETURN c.name as class_name, c.full_name as full_name
                """, file_path=file_path)
                return [{"class_name": record["class_name"], "full_name": record["full_name"]} async for record in result]

            elif query_type == "methods_of_class":
                class_name = kwargs.get('class_name')
                result = await session.run("""
                    MATCH (c:Class)-[:HAS_METHOD]->(m:Method)
                    WHERE c.name CONTAINS $class_name OR c.full_name CONTAINS $class_name
                    RETURN m.name as method_name, m.args as args
                """, class_name=class_name)
                return [{"method_name": record["method_name"], "args": record["args"]} async for record in result]
812 |
813 |
async def main():
    """Example usage: extract one repository into Neo4j, then run sample queries."""
    load_dotenv()

    uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
    user = os.environ.get('NEO4J_USER', 'neo4j')
    password = os.environ.get('NEO4J_PASSWORD', 'password')

    extractor = DirectNeo4jExtractor(uri, user, password)

    try:
        await extractor.initialize()

        # Analyze repository - direct Neo4j, no LLM processing!
        # repo_url = "https://github.com/pydantic/pydantic-ai.git"
        repo_url = "https://github.com/getzep/graphiti.git"
        await extractor.analyze_repository(repo_url)

        # Direct graph queries
        print("\\n=== Direct Neo4j Queries ===")

        # Which files import from models?
        importers = await extractor.search_graph("files_importing", target="models")
        print(f"\\nFiles importing from 'models': {len(importers)}")
        for entry in importers[:3]:
            print(f"- {entry['file']} imports {entry['imports']}")

        # What classes are in a specific file?
        file_classes = await extractor.search_graph("classes_in_file", file_path="pydantic_ai/models/openai.py")
        print(f"\\nClasses in openai.py: {len(file_classes)}")
        for entry in file_classes:
            print(f"- {entry['class_name']}")

        # What methods does OpenAIModel have?
        model_methods = await extractor.search_graph("methods_of_class", class_name="OpenAIModel")
        print(f"\\nMethods of OpenAIModel: {len(model_methods)}")
        for entry in model_methods[:5]:
            print(f"- {entry['method_name']}({', '.join(entry['args'])})")

    finally:
        await extractor.close()
855 |
856 |
# Script entry point: run the async example pipeline against the configured Neo4j.
if __name__ == "__main__":
    asyncio.run(main())
```
--------------------------------------------------------------------------------
/knowledge_graphs/knowledge_graph_validator.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Knowledge Graph Validator
3 |
4 | Validates AI-generated code against Neo4j knowledge graph containing
5 | repository information. Checks imports, methods, attributes, and parameters.
6 | """
7 |
8 | import asyncio
9 | import logging
10 | from typing import Dict, List, Optional, Set, Tuple, Any
11 | from dataclasses import dataclass, field
12 | from enum import Enum
13 | from neo4j import AsyncGraphDatabase
14 |
15 | from ai_script_analyzer import (
16 | AnalysisResult, ImportInfo, MethodCall, AttributeAccess,
17 | FunctionCall, ClassInstantiation
18 | )
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 |
class ValidationStatus(Enum):
    # Outcome categories for a single validated code element.
    VALID = "VALID"          # confirmed against the knowledge graph
    INVALID = "INVALID"      # contradicts the knowledge graph
    UNCERTAIN = "UNCERTAIN"  # cannot be decided (e.g. external library)
    NOT_FOUND = "NOT_FOUND"  # expected in the graph but absent


@dataclass
class ValidationResult:
    """Result of validating a single element"""
    status: ValidationStatus
    confidence: float  # 0.0 to 1.0
    message: str
    details: Dict[str, Any] = field(default_factory=dict)
    suggestions: List[str] = field(default_factory=list)


@dataclass
class ImportValidation:
    """Validation result for an import"""
    import_info: ImportInfo
    validation: ValidationResult
    available_classes: List[str] = field(default_factory=list)
    available_functions: List[str] = field(default_factory=list)


@dataclass
class MethodValidation:
    """Validation result for a method call"""
    method_call: MethodCall
    validation: ValidationResult
    expected_params: List[str] = field(default_factory=list)
    actual_params: List[str] = field(default_factory=list)
    # None when parameter info was unavailable for the method.
    parameter_validation: Optional[ValidationResult] = None


@dataclass
class AttributeValidation:
    """Validation result for attribute access"""
    attribute_access: AttributeAccess
    validation: ValidationResult
    expected_type: Optional[str] = None


@dataclass
class FunctionValidation:
    """Validation result for function call"""
    function_call: FunctionCall
    validation: ValidationResult
    expected_params: List[str] = field(default_factory=list)
    actual_params: List[str] = field(default_factory=list)
    # None when parameter info was unavailable for the function.
    parameter_validation: Optional[ValidationResult] = None


@dataclass
class ClassValidation:
    """Validation result for class instantiation"""
    class_instantiation: ClassInstantiation
    validation: ValidationResult
    constructor_params: List[str] = field(default_factory=list)
    # None when no __init__ signature was found for the class.
    parameter_validation: Optional[ValidationResult] = None


@dataclass
class ScriptValidationResult:
    """Complete validation results for a script"""
    script_path: str
    analysis_result: AnalysisResult
    import_validations: List[ImportValidation] = field(default_factory=list)
    class_validations: List[ClassValidation] = field(default_factory=list)
    method_validations: List[MethodValidation] = field(default_factory=list)
    attribute_validations: List[AttributeValidation] = field(default_factory=list)
    function_validations: List[FunctionValidation] = field(default_factory=list)
    overall_confidence: float = 0.0
    hallucinations_detected: List[Dict[str, Any]] = field(default_factory=list)
99 |
100 | class KnowledgeGraphValidator:
101 | """Validates code against Neo4j knowledge graph"""
102 |
103 | def __init__(self, neo4j_uri: str, neo4j_user: str, neo4j_password: str):
104 | self.neo4j_uri = neo4j_uri
105 | self.neo4j_user = neo4j_user
106 | self.neo4j_password = neo4j_password
107 | self.driver = None
108 |
109 | # Cache for performance
110 | self.module_cache: Dict[str, List[str]] = {}
111 | self.class_cache: Dict[str, Dict[str, Any]] = {}
112 | self.method_cache: Dict[str, List[Dict[str, Any]]] = {}
113 | self.repo_cache: Dict[str, str] = {} # module_name -> repo_name
114 | self.knowledge_graph_modules: Set[str] = set() # Track modules in knowledge graph
115 |
116 | async def initialize(self):
117 | """Initialize Neo4j connection"""
118 | self.driver = AsyncGraphDatabase.driver(
119 | self.neo4j_uri,
120 | auth=(self.neo4j_user, self.neo4j_password)
121 | )
122 | logger.info("Knowledge graph validator initialized")
123 |
124 | async def close(self):
125 | """Close Neo4j connection"""
126 | if self.driver:
127 | await self.driver.close()
128 |
    async def validate_script(self, analysis_result: AnalysisResult) -> ScriptValidationResult:
        """Validate an entire analyzed script against the knowledge graph.

        Order matters: imports are validated first because they populate
        self.knowledge_graph_modules, which the later steps use to decide
        whether an element should be validated or skipped as external.
        """
        result = ScriptValidationResult(
            script_path=analysis_result.file_path,
            analysis_result=analysis_result
        )

        # Validate imports first (builds context for other validations)
        result.import_validations = await self._validate_imports(analysis_result.imports)

        # Validate class instantiations
        result.class_validations = await self._validate_class_instantiations(
            analysis_result.class_instantiations
        )

        # Validate method calls
        result.method_validations = await self._validate_method_calls(
            analysis_result.method_calls
        )

        # Validate attribute accesses
        result.attribute_validations = await self._validate_attribute_accesses(
            analysis_result.attribute_accesses
        )

        # Validate function calls
        result.function_validations = await self._validate_function_calls(
            analysis_result.function_calls
        )

        # Calculate overall confidence and detect hallucinations
        result.overall_confidence = self._calculate_overall_confidence(result)
        result.hallucinations_detected = self._detect_hallucinations(result)

        return result
164 |
165 | async def _validate_imports(self, imports: List[ImportInfo]) -> List[ImportValidation]:
166 | """Validate all imports against knowledge graph"""
167 | validations = []
168 |
169 | for import_info in imports:
170 | validation = await self._validate_single_import(import_info)
171 | validations.append(validation)
172 |
173 | return validations
174 |
    async def _validate_single_import(self, import_info: ImportInfo) -> ImportValidation:
        """Validate a single import.

        Modules found in the graph are recorded in self.knowledge_graph_modules
        (including the base package of dotted paths) so later validators know
        which objects are in scope for validation. Modules not in the graph are
        treated as external libraries, not as errors.
        """
        # Determine module to search for
        search_module = import_info.module if import_info.is_from_import else import_info.name

        # Check cache first
        if search_module in self.module_cache:
            available_files = self.module_cache[search_module]
        else:
            # Query Neo4j for matching modules
            available_files = await self._find_modules(search_module)
            self.module_cache[search_module] = available_files

        if available_files:
            # Get available classes and functions from the module
            classes, functions = await self._get_module_contents(search_module)

            # Track this module as being in the knowledge graph
            self.knowledge_graph_modules.add(search_module)

            # Also track the base module for "from X.Y.Z import ..." patterns
            if '.' in search_module:
                base_module = search_module.split('.')[0]
                self.knowledge_graph_modules.add(base_module)

            validation = ValidationResult(
                status=ValidationStatus.VALID,
                confidence=0.9,
                message=f"Module '{search_module}' found in knowledge graph",
                details={"matched_files": available_files, "in_knowledge_graph": True}
            )

            return ImportValidation(
                import_info=import_info,
                validation=validation,
                available_classes=classes,
                available_functions=functions
            )
        else:
            # External library - mark as such but don't treat as error
            validation = ValidationResult(
                status=ValidationStatus.UNCERTAIN,
                confidence=0.8,  # High confidence it's external, not an error
                message=f"Module '{search_module}' is external (not in knowledge graph)",
                details={"could_be_external": True, "in_knowledge_graph": False}
            )

            return ImportValidation(
                import_info=import_info,
                validation=validation
            )
226 |
227 | async def _validate_class_instantiations(self, instantiations: List[ClassInstantiation]) -> List[ClassValidation]:
228 | """Validate class instantiations"""
229 | validations = []
230 |
231 | for instantiation in instantiations:
232 | validation = await self._validate_single_class_instantiation(instantiation)
233 | validations.append(validation)
234 |
235 | return validations
236 |
    async def _validate_single_class_instantiation(self, instantiation: ClassInstantiation) -> ClassValidation:
        """Validate a single class instantiation.

        Skips classes that were not imported from the knowledge graph, then
        checks the class exists and (when an __init__ signature is available)
        validates the constructor arguments.
        """
        class_name = instantiation.full_class_name or instantiation.class_name

        # Skip validation for classes not from knowledge graph
        if not self._is_from_knowledge_graph(class_name):
            validation = ValidationResult(
                status=ValidationStatus.UNCERTAIN,
                confidence=0.8,
                message=f"Skipping validation: '{class_name}' is not from knowledge graph"
            )
            return ClassValidation(
                class_instantiation=instantiation,
                validation=validation
            )

        # Find class in knowledge graph
        class_info = await self._find_class(class_name)

        if not class_info:
            validation = ValidationResult(
                status=ValidationStatus.NOT_FOUND,
                confidence=0.2,
                message=f"Class '{class_name}' not found in knowledge graph"
            )
            return ClassValidation(
                class_instantiation=instantiation,
                validation=validation
            )

        # Check constructor parameters (look for __init__ method)
        init_method = await self._find_method(class_name, "__init__")

        if init_method:
            param_validation = self._validate_parameters(
                expected_params=init_method.get('params_list', []),
                provided_args=instantiation.args,
                provided_kwargs=instantiation.kwargs
            )
        else:
            param_validation = ValidationResult(
                status=ValidationStatus.UNCERTAIN,
                confidence=0.5,
                message="Constructor parameters not found"
            )

        # Use parameter validation result if it failed
        if param_validation.status == ValidationStatus.INVALID:
            validation = ValidationResult(
                status=ValidationStatus.INVALID,
                confidence=param_validation.confidence,
                message=f"Class '{class_name}' found but has invalid constructor parameters: {param_validation.message}",
                suggestions=param_validation.suggestions
            )
        else:
            validation = ValidationResult(
                status=ValidationStatus.VALID,
                confidence=0.8,
                message=f"Class '{class_name}' found in knowledge graph"
            )

        return ClassValidation(
            class_instantiation=instantiation,
            validation=validation,
            parameter_validation=param_validation
        )
303 |
304 | async def _validate_method_calls(self, method_calls: List[MethodCall]) -> List[MethodValidation]:
305 | """Validate method calls"""
306 | validations = []
307 |
308 | for method_call in method_calls:
309 | validation = await self._validate_single_method_call(method_call)
310 | validations.append(validation)
311 |
312 | return validations
313 |
    async def _validate_single_method_call(self, method_call: MethodCall) -> MethodValidation:
        """Validate a single method call.

        Requires the analyzer to have inferred the receiver's type; calls on
        objects of unknown type, or on types outside the knowledge graph, are
        reported as UNCERTAIN rather than errors. Otherwise the method must
        exist on the class and its arguments must match the stored signature.
        """
        class_type = method_call.object_type

        if not class_type:
            validation = ValidationResult(
                status=ValidationStatus.UNCERTAIN,
                confidence=0.3,
                message=f"Cannot determine object type for '{method_call.object_name}'"
            )
            return MethodValidation(
                method_call=method_call,
                validation=validation
            )

        # Skip validation for classes not from knowledge graph
        if not self._is_from_knowledge_graph(class_type):
            validation = ValidationResult(
                status=ValidationStatus.UNCERTAIN,
                confidence=0.8,
                message=f"Skipping validation: '{class_type}' is not from knowledge graph"
            )
            return MethodValidation(
                method_call=method_call,
                validation=validation
            )

        # Find method in knowledge graph
        method_info = await self._find_method(class_type, method_call.method_name)

        if not method_info:
            # Check for similar method names to offer as suggestions
            similar_methods = await self._find_similar_methods(class_type, method_call.method_name)

            validation = ValidationResult(
                status=ValidationStatus.NOT_FOUND,
                confidence=0.1,
                message=f"Method '{method_call.method_name}' not found on class '{class_type}'",
                suggestions=similar_methods
            )
            return MethodValidation(
                method_call=method_call,
                validation=validation
            )

        # Validate parameters
        expected_params = method_info.get('params_list', [])
        param_validation = self._validate_parameters(
            expected_params=expected_params,
            provided_args=method_call.args,
            provided_kwargs=method_call.kwargs
        )

        # Use parameter validation result if it failed
        if param_validation.status == ValidationStatus.INVALID:
            validation = ValidationResult(
                status=ValidationStatus.INVALID,
                confidence=param_validation.confidence,
                message=f"Method '{method_call.method_name}' found but has invalid parameters: {param_validation.message}",
                suggestions=param_validation.suggestions
            )
        else:
            validation = ValidationResult(
                status=ValidationStatus.VALID,
                confidence=0.9,
                message=f"Method '{method_call.method_name}' found on class '{class_type}'"
            )

        return MethodValidation(
            method_call=method_call,
            validation=validation,
            expected_params=expected_params,
            actual_params=method_call.args + list(method_call.kwargs.keys()),
            parameter_validation=param_validation
        )
389 |
390 | async def _validate_attribute_accesses(self, attribute_accesses: List[AttributeAccess]) -> List[AttributeValidation]:
391 | """Validate attribute accesses"""
392 | validations = []
393 |
394 | for attr_access in attribute_accesses:
395 | validation = await self._validate_single_attribute_access(attr_access)
396 | validations.append(validation)
397 |
398 | return validations
399 |
    async def _validate_single_attribute_access(self, attr_access: AttributeAccess) -> AttributeValidation:
        """Validate a single attribute access.

        Falls back to a method lookup when no attribute matches, since
        decorator usage like @agent.tool accesses a method as an attribute.
        """
        class_type = attr_access.object_type

        if not class_type:
            validation = ValidationResult(
                status=ValidationStatus.UNCERTAIN,
                confidence=0.3,
                message=f"Cannot determine object type for '{attr_access.object_name}'"
            )
            return AttributeValidation(
                attribute_access=attr_access,
                validation=validation
            )

        # Skip validation for classes not from knowledge graph
        if not self._is_from_knowledge_graph(class_type):
            validation = ValidationResult(
                status=ValidationStatus.UNCERTAIN,
                confidence=0.8,
                message=f"Skipping validation: '{class_type}' is not from knowledge graph"
            )
            return AttributeValidation(
                attribute_access=attr_access,
                validation=validation
            )

        # Find attribute in knowledge graph
        attr_info = await self._find_attribute(class_type, attr_access.attribute_name)

        if not attr_info:
            # If not found as attribute, check if it's a method (for decorators like @agent.tool)
            method_info = await self._find_method(class_type, attr_access.attribute_name)

            if method_info:
                validation = ValidationResult(
                    status=ValidationStatus.VALID,
                    confidence=0.8,
                    message=f"'{attr_access.attribute_name}' found as method on class '{class_type}' (likely used as decorator)"
                )
                return AttributeValidation(
                    attribute_access=attr_access,
                    validation=validation,
                    expected_type="method"
                )

            validation = ValidationResult(
                status=ValidationStatus.NOT_FOUND,
                confidence=0.2,
                message=f"'{attr_access.attribute_name}' not found on class '{class_type}'"
            )
            return AttributeValidation(
                attribute_access=attr_access,
                validation=validation
            )

        validation = ValidationResult(
            status=ValidationStatus.VALID,
            confidence=0.8,
            message=f"Attribute '{attr_access.attribute_name}' found on class '{class_type}'"
        )

        return AttributeValidation(
            attribute_access=attr_access,
            validation=validation,
            expected_type=attr_info.get('type')
        )
467 |
468 | async def _validate_function_calls(self, function_calls: List[FunctionCall]) -> List[FunctionValidation]:
469 | """Validate function calls"""
470 | validations = []
471 |
472 | for func_call in function_calls:
473 | validation = await self._validate_single_function_call(func_call)
474 | validations.append(validation)
475 |
476 | return validations
477 |
    async def _validate_single_function_call(self, func_call: FunctionCall) -> FunctionValidation:
        """Validate a single module-level function call.

        Fully-qualified names outside the knowledge graph are skipped as
        external; otherwise the function must exist and its arguments must
        match the stored signature.
        """
        func_name = func_call.full_name or func_call.function_name

        # Skip validation for functions not from knowledge graph
        if func_call.full_name and not self._is_from_knowledge_graph(func_call.full_name):
            validation = ValidationResult(
                status=ValidationStatus.UNCERTAIN,
                confidence=0.8,
                message=f"Skipping validation: '{func_name}' is not from knowledge graph"
            )
            return FunctionValidation(
                function_call=func_call,
                validation=validation
            )

        # Find function in knowledge graph
        func_info = await self._find_function(func_name)

        if not func_info:
            validation = ValidationResult(
                status=ValidationStatus.NOT_FOUND,
                confidence=0.2,
                message=f"Function '{func_name}' not found in knowledge graph"
            )
            return FunctionValidation(
                function_call=func_call,
                validation=validation
            )

        # Validate parameters
        expected_params = func_info.get('params_list', [])
        param_validation = self._validate_parameters(
            expected_params=expected_params,
            provided_args=func_call.args,
            provided_kwargs=func_call.kwargs
        )

        # Use parameter validation result if it failed
        if param_validation.status == ValidationStatus.INVALID:
            validation = ValidationResult(
                status=ValidationStatus.INVALID,
                confidence=param_validation.confidence,
                message=f"Function '{func_name}' found but has invalid parameters: {param_validation.message}",
                suggestions=param_validation.suggestions
            )
        else:
            validation = ValidationResult(
                status=ValidationStatus.VALID,
                confidence=0.8,
                message=f"Function '{func_name}' found in knowledge graph"
            )

        return FunctionValidation(
            function_call=func_call,
            validation=validation,
            expected_params=expected_params,
            actual_params=func_call.args + list(func_call.kwargs.keys()),
            parameter_validation=param_validation
        )
538 |
539 | def _validate_parameters(self, expected_params: List[str], provided_args: List[str],
540 | provided_kwargs: Dict[str, str]) -> ValidationResult:
541 | """Validate function/method parameters with comprehensive support"""
542 | if not expected_params:
543 | return ValidationResult(
544 | status=ValidationStatus.UNCERTAIN,
545 | confidence=0.5,
546 | message="Parameter information not available"
547 | )
548 |
549 | # Parse expected parameters - handle detailed format
550 | required_positional = []
551 | optional_positional = []
552 | keyword_only_required = []
553 | keyword_only_optional = []
554 | has_varargs = False
555 | has_varkwargs = False
556 |
557 | for param in expected_params:
558 | # Handle detailed format: "[keyword_only] name:type=default" or "name:type"
559 | param_clean = param.strip()
560 |
561 | # Check for parameter kind prefix
562 | kind = 'positional'
563 | if param_clean.startswith('['):
564 | end_bracket = param_clean.find(']')
565 | if end_bracket > 0:
566 | kind = param_clean[1:end_bracket]
567 | param_clean = param_clean[end_bracket+1:].strip()
568 |
569 | # Check for varargs/varkwargs
570 | if param_clean.startswith('*') and not param_clean.startswith('**'):
571 | has_varargs = True
572 | continue
573 | elif param_clean.startswith('**'):
574 | has_varkwargs = True
575 | continue
576 |
577 | # Parse name and check if optional
578 | if ':' in param_clean:
579 | param_name = param_clean.split(':')[0]
580 | is_optional = '=' in param_clean
581 |
582 | if kind == 'keyword_only':
583 | if is_optional:
584 | keyword_only_optional.append(param_name)
585 | else:
586 | keyword_only_required.append(param_name)
587 | else: # positional
588 | if is_optional:
589 | optional_positional.append(param_name)
590 | else:
591 | required_positional.append(param_name)
592 |
593 | # Count provided parameters
594 | provided_positional_count = len(provided_args)
595 | provided_keyword_names = set(provided_kwargs.keys())
596 |
597 | # Validate positional arguments
598 | min_required_positional = len(required_positional)
599 | max_allowed_positional = len(required_positional) + len(optional_positional)
600 |
601 | if not has_varargs and provided_positional_count > max_allowed_positional:
602 | return ValidationResult(
603 | status=ValidationStatus.INVALID,
604 | confidence=0.8,
605 | message=f"Too many positional arguments: provided {provided_positional_count}, max allowed {max_allowed_positional}"
606 | )
607 |
608 | if provided_positional_count < min_required_positional:
609 | return ValidationResult(
610 | status=ValidationStatus.INVALID,
611 | confidence=0.8,
612 | message=f"Too few positional arguments: provided {provided_positional_count}, required {min_required_positional}"
613 | )
614 |
615 | # Validate keyword arguments
616 | all_valid_kwarg_names = set(required_positional + optional_positional + keyword_only_required + keyword_only_optional)
617 | invalid_kwargs = provided_keyword_names - all_valid_kwarg_names
618 |
619 | if invalid_kwargs and not has_varkwargs:
620 | return ValidationResult(
621 | status=ValidationStatus.INVALID,
622 | confidence=0.7,
623 | message=f"Invalid keyword arguments: {list(invalid_kwargs)}",
624 | suggestions=[f"Valid parameters: {list(all_valid_kwarg_names)}"]
625 | )
626 |
627 | # Check required keyword-only arguments
628 | missing_required_kwargs = set(keyword_only_required) - provided_keyword_names
629 | if missing_required_kwargs:
630 | return ValidationResult(
631 | status=ValidationStatus.INVALID,
632 | confidence=0.8,
633 | message=f"Missing required keyword arguments: {list(missing_required_kwargs)}"
634 | )
635 |
636 | return ValidationResult(
637 | status=ValidationStatus.VALID,
638 | confidence=0.9,
639 | message="Parameters are valid"
640 | )
641 |
642 | # Neo4j Query Methods
643 |
    async def _find_modules(self, module_name: str) -> List[str]:
        """Find the repository best matching a module name and return its file paths.

        Matching strategy: prefer repositories whose File nodes carry the
        module name (exact, dotted-prefix, or top-level package match); fall
        back to fuzzy repository-name matching (dash/underscore variants).
        Returns up to 50 file paths from the best match, or [] if none.
        """
        async with self.driver.session() as session:
            # First, try to find files with module names that match or start with the search term
            module_query = """
            MATCH (r:Repository)-[:CONTAINS]->(f:File)
            WHERE f.module_name = $module_name
               OR f.module_name STARTS WITH $module_name + '.'
               OR split(f.module_name, '.')[0] = $module_name
            RETURN DISTINCT r.name as repo_name, count(f) as file_count
            ORDER BY file_count DESC
            LIMIT 5
            """

            result = await session.run(module_query, module_name=module_name)
            repos_from_modules = []
            async for record in result:
                repos_from_modules.append(record['repo_name'])

            # Also try repository name matching as fallback
            repo_query = """
            MATCH (r:Repository)
            WHERE toLower(r.name) = toLower($module_name)
               OR toLower(replace(r.name, '-', '_')) = toLower($module_name)
               OR toLower(replace(r.name, '_', '-')) = toLower($module_name)
            RETURN r.name as repo_name
            ORDER BY
                CASE
                    WHEN toLower(r.name) = toLower($module_name) THEN 1
                    WHEN toLower(replace(r.name, '-', '_')) = toLower($module_name) THEN 2
                    WHEN toLower(replace(r.name, '_', '-')) = toLower($module_name) THEN 3
                END
            LIMIT 5
            """

            result = await session.run(repo_query, module_name=module_name)
            repos_from_names = []
            async for record in result:
                repos_from_names.append(record['repo_name'])

            # Combine results, prioritizing module-based matches
            all_repos = repos_from_modules + [r for r in repos_from_names if r not in repos_from_modules]

            if not all_repos:
                return []

            # Get files from the best matching repository
            best_repo = all_repos[0]
            files_query = """
            MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)
            RETURN f.path, f.module_name
            LIMIT 50
            """

            result = await session.run(files_query, repo_name=best_repo)
            files = []
            async for record in result:
                # Unaliased Cypher return values are keyed by expression, e.g. 'f.path'
                files.append(record['f.path'])

            return files
704 |
    async def _get_module_contents(self, module_name: str) -> Tuple[List[str], List[str]]:
        """Get (class names, function names) defined in the repository matching a module.

        Resolution mirrors _find_modules: match by File.module_name first,
        then fall back to fuzzy repository-name matching. Returns ([], [])
        when no repository matches.
        """
        async with self.driver.session() as session:
            # First, try to find repository by module names in files
            module_query = """
            MATCH (r:Repository)-[:CONTAINS]->(f:File)
            WHERE f.module_name = $module_name
               OR f.module_name STARTS WITH $module_name + '.'
               OR split(f.module_name, '.')[0] = $module_name
            RETURN DISTINCT r.name as repo_name, count(f) as file_count
            ORDER BY file_count DESC
            LIMIT 1
            """

            result = await session.run(module_query, module_name=module_name)
            record = await result.single()

            if record:
                repo_name = record['repo_name']
            else:
                # Fallback to repository name matching
                repo_query = """
                MATCH (r:Repository)
                WHERE toLower(r.name) = toLower($module_name)
                   OR toLower(replace(r.name, '-', '_')) = toLower($module_name)
                   OR toLower(replace(r.name, '_', '-')) = toLower($module_name)
                RETURN r.name as repo_name
                ORDER BY
                    CASE
                        WHEN toLower(r.name) = toLower($module_name) THEN 1
                        WHEN toLower(replace(r.name, '-', '_')) = toLower($module_name) THEN 2
                        WHEN toLower(replace(r.name, '_', '-')) = toLower($module_name) THEN 3
                    END
                LIMIT 1
                """

                result = await session.run(repo_query, module_name=module_name)
                record = await result.single()

                if not record:
                    return [], []

                repo_name = record['repo_name']

            # Get classes from this repository
            class_query = """
            MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)
            RETURN DISTINCT c.name as class_name
            """

            result = await session.run(class_query, repo_name=repo_name)
            classes = []
            async for record in result:
                classes.append(record['class_name'])

            # Get functions from this repository
            func_query = """
            MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(func:Function)
            RETURN DISTINCT func.name as function_name
            """

            result = await session.run(func_query, repo_name=repo_name)
            functions = []
            async for record in result:
                functions.append(record['function_name'])

            return classes, functions
772 |
    async def _find_repository_for_module(self, module_name: str) -> Optional[str]:
        """Find the repository name matching a module name (cached, may be None).

        Results — including misses (None) — are memoized in self.repo_cache.
        The fallback name matching here is looser than elsewhere: it also
        accepts substring containment between repo and module names.
        """
        if module_name in self.repo_cache:
            return self.repo_cache[module_name]

        async with self.driver.session() as session:
            # First, try to find repository by module names in files
            module_query = """
            MATCH (r:Repository)-[:CONTAINS]->(f:File)
            WHERE f.module_name = $module_name
               OR f.module_name STARTS WITH $module_name + '.'
               OR split(f.module_name, '.')[0] = $module_name
            RETURN DISTINCT r.name as repo_name, count(f) as file_count
            ORDER BY file_count DESC
            LIMIT 1
            """

            result = await session.run(module_query, module_name=module_name)
            record = await result.single()

            if record:
                repo_name = record['repo_name']
            else:
                # Fallback to repository name matching
                query = """
                MATCH (r:Repository)
                WHERE toLower(r.name) = toLower($module_name)
                   OR toLower(replace(r.name, '-', '_')) = toLower($module_name)
                   OR toLower(replace(r.name, '_', '-')) = toLower($module_name)
                   OR toLower(r.name) CONTAINS toLower($module_name)
                   OR toLower($module_name) CONTAINS toLower(replace(r.name, '-', '_'))
                RETURN r.name as repo_name
                ORDER BY
                    CASE
                        WHEN toLower(r.name) = toLower($module_name) THEN 1
                        WHEN toLower(replace(r.name, '-', '_')) = toLower($module_name) THEN 2
                        ELSE 3
                    END
                LIMIT 1
                """

                result = await session.run(query, module_name=module_name)
                record = await result.single()

                repo_name = record['repo_name'] if record else None

            self.repo_cache[module_name] = repo_name
            return repo_name
821 |
    async def _find_class(self, class_name: str) -> Optional[Dict[str, Any]]:
        """Find class information in the knowledge graph.

        Tries an exact match on name/full_name first; for dotted names
        (e.g. "pydantic_ai.Agent") it resolves the module part to a
        repository and searches for the bare class name inside it.
        Returns {'name', 'full_name'} or None.
        """
        async with self.driver.session() as session:
            # First try exact match
            query = """
            MATCH (c:Class)
            WHERE c.name = $class_name OR c.full_name = $class_name
            RETURN c.name as name, c.full_name as full_name
            LIMIT 1
            """

            result = await session.run(query, class_name=class_name)
            record = await result.single()

            if record:
                return {
                    'name': record['name'],
                    'full_name': record['full_name']
                }

            # If no exact match and class_name has dots, try repository-based search
            if '.' in class_name:
                parts = class_name.split('.')
                module_part = '.'.join(parts[:-1])  # e.g., "pydantic_ai"
                class_part = parts[-1]  # e.g., "Agent"

                # Find repository for the module
                repo_name = await self._find_repository_for_module(module_part)

                if repo_name:
                    # Search for class within this repository
                    repo_query = """
                    MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)
                    WHERE c.name = $class_name
                    RETURN c.name as name, c.full_name as full_name
                    LIMIT 1
                    """

                    result = await session.run(repo_query, repo_name=repo_name, class_name=class_part)
                    record = await result.single()

                    if record:
                        return {
                            'name': record['name'],
                            'full_name': record['full_name']
                        }

            return None
870 |
871 | async def _find_method(self, class_name: str, method_name: str) -> Optional[Dict[str, Any]]:
872 | """Find method information for a class"""
873 | cache_key = f"{class_name}.{method_name}"
874 | if cache_key in self.method_cache:
875 | methods = self.method_cache[cache_key]
876 | return methods[0] if methods else None
877 |
878 | async with self.driver.session() as session:
879 | # First try exact match
880 | query = """
881 | MATCH (c:Class)-[:HAS_METHOD]->(m:Method)
882 | WHERE (c.name = $class_name OR c.full_name = $class_name)
883 | AND m.name = $method_name
884 | RETURN m.name as name, m.params_list as params_list, m.params_detailed as params_detailed,
885 | m.return_type as return_type, m.args as args
886 | LIMIT 1
887 | """
888 |
889 | result = await session.run(query, class_name=class_name, method_name=method_name)
890 | record = await result.single()
891 |
892 | if record:
893 | # Use detailed params if available, fall back to simple params
894 | params_to_use = record['params_detailed'] or record['params_list'] or []
895 |
896 | method_info = {
897 | 'name': record['name'],
898 | 'params_list': params_to_use,
899 | 'return_type': record['return_type'],
900 | 'args': record['args'] or []
901 | }
902 | self.method_cache[cache_key] = [method_info]
903 | return method_info
904 |
905 | # If no exact match and class_name has dots, try repository-based search
906 | if '.' in class_name:
907 | parts = class_name.split('.')
908 | module_part = '.'.join(parts[:-1]) # e.g., "pydantic_ai"
909 | class_part = parts[-1] # e.g., "Agent"
910 |
911 | # Find repository for the module
912 | repo_name = await self._find_repository_for_module(module_part)
913 |
914 | if repo_name:
915 | # Search for method within this repository's classes
916 | repo_query = """
917 | MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)-[:HAS_METHOD]->(m:Method)
918 | WHERE c.name = $class_name AND m.name = $method_name
919 | RETURN m.name as name, m.params_list as params_list, m.params_detailed as params_detailed,
920 | m.return_type as return_type, m.args as args
921 | LIMIT 1
922 | """
923 |
924 | result = await session.run(repo_query, repo_name=repo_name, class_name=class_part, method_name=method_name)
925 | record = await result.single()
926 |
927 | if record:
928 | # Use detailed params if available, fall back to simple params
929 | params_to_use = record['params_detailed'] or record['params_list'] or []
930 |
931 | method_info = {
932 | 'name': record['name'],
933 | 'params_list': params_to_use,
934 | 'return_type': record['return_type'],
935 | 'args': record['args'] or []
936 | }
937 | self.method_cache[cache_key] = [method_info]
938 | return method_info
939 |
940 | self.method_cache[cache_key] = []
941 | return None
942 |
943 | async def _find_attribute(self, class_name: str, attr_name: str) -> Optional[Dict[str, Any]]:
944 | """Find attribute information for a class"""
945 | async with self.driver.session() as session:
946 | # First try exact match
947 | query = """
948 | MATCH (c:Class)-[:HAS_ATTRIBUTE]->(a:Attribute)
949 | WHERE (c.name = $class_name OR c.full_name = $class_name)
950 | AND a.name = $attr_name
951 | RETURN a.name as name, a.type as type
952 | LIMIT 1
953 | """
954 |
955 | result = await session.run(query, class_name=class_name, attr_name=attr_name)
956 | record = await result.single()
957 |
958 | if record:
959 | return {
960 | 'name': record['name'],
961 | 'type': record['type']
962 | }
963 |
964 | # If no exact match and class_name has dots, try repository-based search
965 | if '.' in class_name:
966 | parts = class_name.split('.')
967 | module_part = '.'.join(parts[:-1]) # e.g., "pydantic_ai"
968 | class_part = parts[-1] # e.g., "Agent"
969 |
970 | # Find repository for the module
971 | repo_name = await self._find_repository_for_module(module_part)
972 |
973 | if repo_name:
974 | # Search for attribute within this repository's classes
975 | repo_query = """
976 | MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)-[:HAS_ATTRIBUTE]->(a:Attribute)
977 | WHERE c.name = $class_name AND a.name = $attr_name
978 | RETURN a.name as name, a.type as type
979 | LIMIT 1
980 | """
981 |
982 | result = await session.run(repo_query, repo_name=repo_name, class_name=class_part, attr_name=attr_name)
983 | record = await result.single()
984 |
985 | if record:
986 | return {
987 | 'name': record['name'],
988 | 'type': record['type']
989 | }
990 |
991 | return None
992 |
993 | async def _find_function(self, func_name: str) -> Optional[Dict[str, Any]]:
994 | """Find function information"""
995 | async with self.driver.session() as session:
996 | # First try exact match
997 | query = """
998 | MATCH (f:Function)
999 | WHERE f.name = $func_name OR f.full_name = $func_name
1000 | RETURN f.name as name, f.params_list as params_list, f.params_detailed as params_detailed,
1001 | f.return_type as return_type, f.args as args
1002 | LIMIT 1
1003 | """
1004 |
1005 | result = await session.run(query, func_name=func_name)
1006 | record = await result.single()
1007 |
1008 | if record:
1009 | # Use detailed params if available, fall back to simple params
1010 | params_to_use = record['params_detailed'] or record['params_list'] or []
1011 |
1012 | return {
1013 | 'name': record['name'],
1014 | 'params_list': params_to_use,
1015 | 'return_type': record['return_type'],
1016 | 'args': record['args'] or []
1017 | }
1018 |
1019 | # If no exact match and func_name has dots, try repository-based search
1020 | if '.' in func_name:
1021 | parts = func_name.split('.')
1022 | module_part = '.'.join(parts[:-1]) # e.g., "pydantic_ai"
1023 | func_part = parts[-1] # e.g., "some_function"
1024 |
1025 | # Find repository for the module
1026 | repo_name = await self._find_repository_for_module(module_part)
1027 |
1028 | if repo_name:
1029 | # Search for function within this repository
1030 | repo_query = """
1031 | MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(func:Function)
1032 | WHERE func.name = $func_name
1033 | RETURN func.name as name, func.params_list as params_list, func.params_detailed as params_detailed,
1034 | func.return_type as return_type, func.args as args
1035 | LIMIT 1
1036 | """
1037 |
1038 | result = await session.run(repo_query, repo_name=repo_name, func_name=func_part)
1039 | record = await result.single()
1040 |
1041 | if record:
1042 | # Use detailed params if available, fall back to simple params
1043 | params_to_use = record['params_detailed'] or record['params_list'] or []
1044 |
1045 | return {
1046 | 'name': record['name'],
1047 | 'params_list': params_to_use,
1048 | 'return_type': record['return_type'],
1049 | 'args': record['args'] or []
1050 | }
1051 |
1052 | return None
1053 |
1054 | async def _find_pydantic_ai_result_method(self, method_name: str) -> Optional[Dict[str, Any]]:
1055 | """Find method information for pydantic_ai result objects"""
1056 | # Look for methods on pydantic_ai classes that could be result objects
1057 | async with self.driver.session() as session:
1058 | # Search for common result methods in pydantic_ai repository
1059 | query = """
1060 | MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)-[:HAS_METHOD]->(m:Method)
1061 | WHERE m.name = $method_name
1062 | AND (c.name CONTAINS 'Result' OR c.name CONTAINS 'Stream' OR c.name CONTAINS 'Run')
1063 | RETURN m.name as name, m.params_list as params_list, m.params_detailed as params_detailed,
1064 | m.return_type as return_type, m.args as args, c.name as class_name
1065 | LIMIT 1
1066 | """
1067 |
1068 | result = await session.run(query, repo_name="pydantic_ai", method_name=method_name)
1069 | record = await result.single()
1070 |
1071 | if record:
1072 | # Use detailed params if available, fall back to simple params
1073 | params_to_use = record['params_detailed'] or record['params_list'] or []
1074 |
1075 | return {
1076 | 'name': record['name'],
1077 | 'params_list': params_to_use,
1078 | 'return_type': record['return_type'],
1079 | 'args': record['args'] or [],
1080 | 'source_class': record['class_name']
1081 | }
1082 |
1083 | return None
1084 |
1085 | async def _find_similar_modules(self, module_name: str) -> List[str]:
1086 | """Find similar repository names for suggestions"""
1087 | async with self.driver.session() as session:
1088 | query = """
1089 | MATCH (r:Repository)
1090 | WHERE toLower(r.name) CONTAINS toLower($partial_name)
1091 | OR toLower(replace(r.name, '-', '_')) CONTAINS toLower($partial_name)
1092 | OR toLower(replace(r.name, '_', '-')) CONTAINS toLower($partial_name)
1093 | RETURN r.name
1094 | LIMIT 5
1095 | """
1096 |
1097 | result = await session.run(query, partial_name=module_name[:3])
1098 | suggestions = []
1099 | async for record in result:
1100 | suggestions.append(record['name'])
1101 |
1102 | return suggestions
1103 |
1104 | async def _find_similar_methods(self, class_name: str, method_name: str) -> List[str]:
1105 | """Find similar method names for suggestions"""
1106 | async with self.driver.session() as session:
1107 | # First try exact class match
1108 | query = """
1109 | MATCH (c:Class)-[:HAS_METHOD]->(m:Method)
1110 | WHERE (c.name = $class_name OR c.full_name = $class_name)
1111 | AND m.name CONTAINS $partial_name
1112 | RETURN m.name as name
1113 | LIMIT 5
1114 | """
1115 |
1116 | result = await session.run(query, class_name=class_name, partial_name=method_name[:3])
1117 | suggestions = []
1118 | async for record in result:
1119 | suggestions.append(record['name'])
1120 |
1121 | # If no suggestions and class_name has dots, try repository-based search
1122 | if not suggestions and '.' in class_name:
1123 | parts = class_name.split('.')
1124 | module_part = '.'.join(parts[:-1]) # e.g., "pydantic_ai"
1125 | class_part = parts[-1] # e.g., "Agent"
1126 |
1127 | # Find repository for the module
1128 | repo_name = await self._find_repository_for_module(module_part)
1129 |
1130 | if repo_name:
1131 | repo_query = """
1132 | MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)-[:HAS_METHOD]->(m:Method)
1133 | WHERE c.name = $class_name AND m.name CONTAINS $partial_name
1134 | RETURN m.name as name
1135 | LIMIT 5
1136 | """
1137 |
1138 | result = await session.run(repo_query, repo_name=repo_name, class_name=class_part, partial_name=method_name[:3])
1139 | async for record in result:
1140 | suggestions.append(record['name'])
1141 |
1142 | return suggestions
1143 |
1144 | def _calculate_overall_confidence(self, result: ScriptValidationResult) -> float:
1145 | """Calculate overall confidence score for the validation (knowledge graph items only)"""
1146 | kg_validations = []
1147 |
1148 | # Only count validations from knowledge graph imports
1149 | for val in result.import_validations:
1150 | if val.validation.details.get('in_knowledge_graph', False):
1151 | kg_validations.append(val.validation.confidence)
1152 |
1153 | # Only count validations from knowledge graph classes
1154 | for val in result.class_validations:
1155 | class_name = val.class_instantiation.full_class_name or val.class_instantiation.class_name
1156 | if self._is_from_knowledge_graph(class_name):
1157 | kg_validations.append(val.validation.confidence)
1158 |
1159 | # Only count validations from knowledge graph methods
1160 | for val in result.method_validations:
1161 | if val.method_call.object_type and self._is_from_knowledge_graph(val.method_call.object_type):
1162 | kg_validations.append(val.validation.confidence)
1163 |
1164 | # Only count validations from knowledge graph attributes
1165 | for val in result.attribute_validations:
1166 | if val.attribute_access.object_type and self._is_from_knowledge_graph(val.attribute_access.object_type):
1167 | kg_validations.append(val.validation.confidence)
1168 |
1169 | # Only count validations from knowledge graph functions
1170 | for val in result.function_validations:
1171 | if val.function_call.full_name and self._is_from_knowledge_graph(val.function_call.full_name):
1172 | kg_validations.append(val.validation.confidence)
1173 |
1174 | if not kg_validations:
1175 | return 1.0 # No knowledge graph items to validate = perfect confidence
1176 |
1177 | return sum(kg_validations) / len(kg_validations)
1178 |
1179 | def _is_from_knowledge_graph(self, class_type: str) -> bool:
1180 | """Check if a class type comes from a module in the knowledge graph"""
1181 | if not class_type:
1182 | return False
1183 |
1184 | # For dotted names like "pydantic_ai.Agent" or "pydantic_ai.StreamedRunResult", check the base module
1185 | if '.' in class_type:
1186 | base_module = class_type.split('.')[0]
1187 | # Exact match only - "pydantic" should not match "pydantic_ai"
1188 | return base_module in self.knowledge_graph_modules
1189 |
1190 | # For simple names, check if any knowledge graph module matches exactly
1191 | # Don't use substring matching to avoid "pydantic" matching "pydantic_ai"
1192 | return class_type in self.knowledge_graph_modules
1193 |
1194 | def _detect_hallucinations(self, result: ScriptValidationResult) -> List[Dict[str, Any]]:
1195 | """Detect and categorize hallucinations"""
1196 | hallucinations = []
1197 | reported_items = set() # Track reported items to avoid duplicates
1198 |
1199 | # Check method calls (only for knowledge graph classes)
1200 | for val in result.method_validations:
1201 | if (val.validation.status == ValidationStatus.NOT_FOUND and
1202 | val.method_call.object_type and
1203 | self._is_from_knowledge_graph(val.method_call.object_type)):
1204 |
1205 | # Create unique key to avoid duplicates
1206 | key = (val.method_call.line_number, val.method_call.method_name, val.method_call.object_type)
1207 | if key not in reported_items:
1208 | reported_items.add(key)
1209 | hallucinations.append({
1210 | 'type': 'METHOD_NOT_FOUND',
1211 | 'location': f"line {val.method_call.line_number}",
1212 | 'description': f"Method '{val.method_call.method_name}' not found on class '{val.method_call.object_type}'",
1213 | 'suggestion': val.validation.suggestions[0] if val.validation.suggestions else None
1214 | })
1215 |
1216 | # Check attributes (only for knowledge graph classes) - but skip if already reported as method
1217 | for val in result.attribute_validations:
1218 | if (val.validation.status == ValidationStatus.NOT_FOUND and
1219 | val.attribute_access.object_type and
1220 | self._is_from_knowledge_graph(val.attribute_access.object_type)):
1221 |
1222 | # Create unique key - if this was already reported as a method, skip it
1223 | key = (val.attribute_access.line_number, val.attribute_access.attribute_name, val.attribute_access.object_type)
1224 | if key not in reported_items:
1225 | reported_items.add(key)
1226 | hallucinations.append({
1227 | 'type': 'ATTRIBUTE_NOT_FOUND',
1228 | 'location': f"line {val.attribute_access.line_number}",
1229 | 'description': f"Attribute '{val.attribute_access.attribute_name}' not found on class '{val.attribute_access.object_type}'"
1230 | })
1231 |
1232 | # Check parameter issues (only for knowledge graph methods)
1233 | for val in result.method_validations:
1234 | if (val.parameter_validation and
1235 | val.parameter_validation.status == ValidationStatus.INVALID and
1236 | val.method_call.object_type and
1237 | self._is_from_knowledge_graph(val.method_call.object_type)):
1238 | hallucinations.append({
1239 | 'type': 'INVALID_PARAMETERS',
1240 | 'location': f"line {val.method_call.line_number}",
1241 | 'description': f"Invalid parameters for method '{val.method_call.method_name}': {val.parameter_validation.message}"
1242 | })
1243 |
1244 | return hallucinations
```