This is page 5 of 6. Use http://codebase.md/chillbruhhh/crawl4ai-mcp?page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .gitattributes
├── .gitignore
├── crawled_pages.sql
├── Dockerfile
├── knowledge_graphs
│ ├── ai_hallucination_detector.py
│ ├── ai_script_analyzer.py
│ ├── hallucination_reporter.py
│ ├── knowledge_graph_validator.py
│ ├── parse_repo_into_neo4j.py
│ ├── query_knowledge_graph.py
│ └── test_script.py
├── LICENSE
├── neo4j
│ └── docker-neo4j
│ ├── .github
│ │ └── ISSUE_TEMPLATE
│ │ └── bug_report.md
│ ├── .gitignore
│ ├── build-docker-image.sh
│ ├── build-utils-common-functions.sh
│ ├── COPYRIGHT
│ ├── DEVELOPMENT.md
│ ├── devenv
│ ├── devenv.local.template
│ ├── docker-image-src
│ │ ├── 2.3
│ │ │ ├── docker-entrypoint.sh
│ │ │ └── Dockerfile
│ │ ├── 3.0
│ │ │ ├── docker-entrypoint.sh
│ │ │ └── Dockerfile
│ │ ├── 3.1
│ │ │ ├── docker-entrypoint.sh
│ │ │ └── Dockerfile
│ │ ├── 3.2
│ │ │ ├── docker-entrypoint.sh
│ │ │ └── Dockerfile
│ │ ├── 3.3
│ │ │ ├── docker-entrypoint.sh
│ │ │ └── Dockerfile
│ │ ├── 3.4
│ │ │ ├── docker-entrypoint.sh
│ │ │ └── Dockerfile
│ │ ├── 3.5
│ │ │ ├── coredb
│ │ │ │ ├── docker-entrypoint.sh
│ │ │ │ ├── Dockerfile
│ │ │ │ └── neo4j-plugins.json
│ │ │ └── neo4j-admin
│ │ │ ├── docker-entrypoint.sh
│ │ │ └── Dockerfile
│ │ ├── 4.0
│ │ │ ├── coredb
│ │ │ │ ├── docker-entrypoint.sh
│ │ │ │ └── Dockerfile
│ │ │ └── neo4j-admin
│ │ │ ├── docker-entrypoint.sh
│ │ │ └── Dockerfile
│ │ ├── 4.1
│ │ │ ├── coredb
│ │ │ │ ├── docker-entrypoint.sh
│ │ │ │ └── Dockerfile
│ │ │ └── neo4j-admin
│ │ │ ├── docker-entrypoint.sh
│ │ │ └── Dockerfile
│ │ ├── 4.2
│ │ │ ├── coredb
│ │ │ │ ├── docker-entrypoint.sh
│ │ │ │ ├── Dockerfile
│ │ │ │ └── neo4j-plugins.json
│ │ │ └── neo4j-admin
│ │ │ ├── docker-entrypoint.sh
│ │ │ └── Dockerfile
│ │ ├── 4.3
│ │ │ ├── coredb
│ │ │ │ ├── docker-entrypoint.sh
│ │ │ │ ├── Dockerfile
│ │ │ │ └── neo4j-plugins.json
│ │ │ └── neo4j-admin
│ │ │ ├── docker-entrypoint.sh
│ │ │ └── Dockerfile
│ │ ├── 4.4
│ │ │ ├── coredb
│ │ │ │ ├── docker-entrypoint.sh
│ │ │ │ ├── Dockerfile-debian
│ │ │ │ ├── Dockerfile-ubi9
│ │ │ │ ├── neo4j-admin-report.sh
│ │ │ │ └── neo4j-plugins.json
│ │ │ └── neo4j-admin
│ │ │ ├── docker-entrypoint.sh
│ │ │ ├── Dockerfile-debian
│ │ │ └── Dockerfile-ubi9
│ │ ├── 5
│ │ │ ├── coredb
│ │ │ │ ├── docker-entrypoint.sh
│ │ │ │ ├── Dockerfile-debian
│ │ │ │ ├── Dockerfile-ubi8
│ │ │ │ ├── Dockerfile-ubi9
│ │ │ │ ├── neo4j-admin-report.sh
│ │ │ │ └── neo4j-plugins.json
│ │ │ └── neo4j-admin
│ │ │ ├── docker-entrypoint.sh
│ │ │ ├── Dockerfile-debian
│ │ │ ├── Dockerfile-ubi8
│ │ │ └── Dockerfile-ubi9
│ │ ├── calver
│ │ │ ├── coredb
│ │ │ │ ├── docker-entrypoint.sh
│ │ │ │ ├── Dockerfile-debian
│ │ │ │ ├── Dockerfile-ubi9
│ │ │ │ ├── neo4j-admin-report.sh
│ │ │ │ └── neo4j-plugins.json
│ │ │ └── neo4j-admin
│ │ │ ├── docker-entrypoint.sh
│ │ │ ├── Dockerfile-debian
│ │ │ └── Dockerfile-ubi9
│ │ └── common
│ │ ├── semver.jq
│ │ └── utilities.sh
│ ├── generate-stub-plugin
│ │ ├── build.gradle.kts
│ │ ├── Dockerfile
│ │ ├── ExampleNeo4jPlugin.java
│ │ ├── Makefile
│ │ ├── README.md
│ │ └── settings.gradle.kts
│ ├── LICENSE
│ ├── Makefile
│ ├── pom.xml
│ ├── publish-neo4j-admin-image.sh
│ ├── publish-neo4j-admin-images.sh
│ ├── README.md
│ └── src
│ ├── main
│ │ └── resources
│ │ └── log4j.properties
│ └── test
│ ├── java
│ │ └── com
│ │ └── neo4j
│ │ └── docker
│ │ ├── coredb
│ │ │ ├── configurations
│ │ │ │ ├── Configuration.java
│ │ │ │ ├── Setting.java
│ │ │ │ ├── TestConfSettings.java
│ │ │ │ ├── TestExtendedConf.java
│ │ │ │ └── TestJVMAdditionalConfig.java
│ │ │ ├── plugins
│ │ │ │ ├── Neo4jPluginEnv.java
│ │ │ │ ├── StubPluginHelper.java
│ │ │ │ ├── TestBundledPluginInstallation.java
│ │ │ │ ├── TestPluginInstallation.java
│ │ │ │ └── TestSemVerPluginMatching.java
│ │ │ ├── TestAdminReport.java
│ │ │ ├── TestAuthentication.java
│ │ │ ├── TestBasic.java
│ │ │ ├── TestCausalCluster.java
│ │ │ ├── TestMounting.java
│ │ │ └── TestUpgrade.java
│ │ ├── neo4jadmin
│ │ │ ├── TestAdminBasic.java
│ │ │ ├── TestBackupRestore.java
│ │ │ ├── TestBackupRestore44.java
│ │ │ ├── TestDumpLoad.java
│ │ │ ├── TestDumpLoad44.java
│ │ │ └── TestReport.java
│ │ ├── TestDeprecationWarning.java
│ │ ├── TestDockerComposeSecrets.java
│ │ └── utils
│ │ ├── DatabaseIO.java
│ │ ├── HostFileHttpHandler.java
│ │ ├── HttpServerTestExtension.java
│ │ ├── Neo4jVersion.java
│ │ ├── Neo4jVersionTest.java
│ │ ├── Network.java
│ │ ├── SetContainerUser.java
│ │ ├── TemporaryFolderManager.java
│ │ ├── TemporaryFolderManagerTest.java
│ │ ├── TestSettings.java
│ │ └── WaitStrategies.java
│ └── resources
│ ├── causal-cluster-compose.yml
│ ├── confs
│ │ ├── before50
│ │ │ ├── ConfsNotOverridden.conf
│ │ │ ├── ConfsReplaced.conf
│ │ │ ├── EnterpriseOnlyNotOverwritten.conf
│ │ │ ├── EnvVarsOverride.conf
│ │ │ ├── ExtendedConf.conf
│ │ │ ├── InvalidExtendedConf.conf
│ │ │ ├── JvmAdditionalNotOverridden.conf
│ │ │ ├── NoNewline.conf
│ │ │ └── ReadConf.conf
│ │ ├── ConfsNotOverridden.conf
│ │ ├── ConfsReplaced.conf
│ │ ├── EnterpriseOnlyNotOverwritten.conf
│ │ ├── EnvVarsOverride.conf
│ │ ├── ExtendedConf.conf
│ │ ├── InvalidExtendedConf.conf
│ │ ├── JvmAdditionalNotOverridden.conf
│ │ ├── NoNewline.conf
│ │ └── ReadConf.conf
│ ├── dockersecrets
│ │ ├── container-compose-with-incorrect-secrets.yml
│ │ ├── container-compose-with-secrets-override.yml
│ │ ├── container-compose-with-secrets.yml
│ │ ├── simple-container-compose-with-external-file-var.yml
│ │ └── simple-container-compose.yml
│ ├── ha-cluster-compose.yml
│ └── stubplugin
│ └── myPlugin.jar
├── pyproject.toml
├── README.md
├── src
│ ├── crawl4ai_mcp.py
│ └── utils.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/neo4j/docker-neo4j/docker-image-src/5/coredb/docker-entrypoint.sh:
--------------------------------------------------------------------------------
```bash
#!/bin/bash -eu
# Neo4j docker entrypoint: -e aborts on any command failure, -u treats unset variables as errors.
# The first positional argument is the command to run ("neo4j", "dump-config", or an arbitrary command).
cmd="$1"
# load useful utility functions
# (provides debug_msg, containsElement, running_as_root, secure_mode_enabled,
#  print_permissions_advice_and_fail, debugging_enabled)
. /startup/utilities.sh
function is_readable
{
    # Check whether ${userid} (with supplementary ${groups[@]}) could read the given
    # file, regardless of which user this script is currently running as.
    # It would be nice if the writability tests could use this logic somehow.
    local _target=${1}
    perm=$(stat -c %a "${_target}")

    # world-readable?
    if [[ ${perm:2:1} -ge 4 ]]; then
        return 0
    fi

    # owner-readable, and we are the owner (matched by name or numeric uid)?
    if [[ ${perm:0:1} -ge 4 ]]; then
        if [[ "$(stat -c %U ${_target})" = "${userid}" ]] || [[ "$(stat -c %u ${_target})" = "${userid}" ]]; then
            return 0
        fi
    fi

    # group-readable, and the file's group (numeric gid or name) is one of ours?
    if [[ ${perm:1:1} -ge 4 ]]; then
        if containsElement "$(stat -c %g ${_target})" "${groups[@]}" || containsElement "$(stat -c %G ${_target})" "${groups[@]}"; then
            return 0
        fi
    fi

    return 1
}
function is_writable
{
    # Check whether ${userid} (with supplementary ${groups[@]}) could write the given
    # file, regardless of which user this script is currently running as.
    # It would be nice if this and the is_readable function could combine somehow.
    local _target=${1}
    perm=$(stat -c %a "${_target}")

    # world-writable? (octal digits with the write bit set: 2, 3, 6, 7)
    if containsElement ${perm:2:1} 2 3 6 7; then
        return 0
    fi

    # owner-writable, and we are the owner (matched by name or numeric uid)?
    if containsElement ${perm:0:1} 2 3 6 7; then
        if [[ "$(stat -c %U ${_target})" = "${userid}" ]] || [[ "$(stat -c %u ${_target})" = "${userid}" ]]; then
            return 0
        fi
    fi

    # group-writable, and the file's group (numeric gid or name) is one of ours?
    if containsElement ${perm:1:1} 2 3 6 7; then
        if containsElement "$(stat -c %g ${_target})" "${groups[@]}" || containsElement "$(stat -c %G ${_target})" "${groups[@]}"; then
            return 0
        fi
    fi

    return 1
}
function check_mounted_folder_readable
{
    # Abort (with permission advice for the user) unless the mounted path is
    # readable by the effective neo4j user.
    local _dir=${1}
    debug_msg "checking ${_dir} is readable"
    is_readable "${_dir}" || print_permissions_advice_and_fail "${_dir}" "${userid}" "${groupid}"
}
function check_mounted_folder_writable_with_chown
{
# Ensure a mounted folder is writable by the neo4j process user, chowning it when
# we are allowed to (running as root), or failing with advice when we are not.
#
# The /data and /log directory are a bit different because they are very likely to be mounted by the user but not
# necessarily writable.
# This depends on whether a user ID is passed to the container and which folders are mounted.
#
# No user ID passed to container:
# 1) No folders are mounted.
# The /data and /log folder are owned by neo4j by default, so should be writable already.
# 2) Both /log and /data are mounted.
# This means on start up, /data and /logs are owned by an unknown user and we should chown them to neo4j for
# backwards compatibility.
#
# User ID passed to container:
# 1) Both /data and /logs are mounted
# The /data and /logs folders are owned by an unknown user but we *should* have rw permission to them.
# That should be verified and error (helpfully) if not.
# 2) User mounts /data or /logs *but not both*
# The unmounted folder is still owned by neo4j, which should already be writable. The mounted folder should
# have rw permissions through user id. This should be verified.
# 3) No folders are mounted.
# The /data and /log folder are owned by neo4j by default, and these are already writable by the user.
# (This is a very unlikely use case).
local mountFolder=${1}
debug_msg "checking ${mountFolder} is writable"
if running_as_root && ! secure_mode_enabled; then
# check folder permissions
if ! is_writable "${mountFolder}" ; then
# warn that we're about to chown the folder and then chown it
echo "Warning: Folder mounted to \"${mountFolder}\" is not writable from inside container. Changing folder owner to ${userid}."
chown -R "${userid}":"${groupid}" "${mountFolder}"
# check permissions on files in the folder
# (the find runs as the neo4j user via su-exec, so "-not -writable" reflects that user's view)
elif [ $(su-exec "${userid}":"${groupid}" find "${mountFolder}" -not -writable | wc -l) -gt 0 ]; then
echo "Warning: Some files inside \"${mountFolder}\" are not writable from inside container. Changing folder owner to ${userid}."
chown -R "${userid}":"${groupid}" "${mountFolder}"
fi
else
# Not root (or secure mode): we cannot chown, so just verify. A folder still owned
# by the "neo4j" user is accepted even if -w fails for the current user.
if [[ ! -w "${mountFolder}" ]] && [[ "$(stat -c %U ${mountFolder})" != "neo4j" ]]; then
print_permissions_advice_and_fail "${mountFolder}" "${userid}" "${groupid}"
fi
fi
}
function load_plugin_from_location
{
    # Install a plugin that already ships inside the container filesystem.
    #   ${1} - plugin name (used for the destination jar file name)
    #   ${2} - source location; may be a glob pattern matching the plugin jar
    local _plugin_name="${1}"
    local _location="${2}"

    # Prefer a mounted /plugins folder over the in-home plugins directory.
    local _plugins_dir="${NEO4J_HOME}/plugins"
    if [ -d /plugins ]; then
        _plugins_dir="/plugins"
    fi
    # Using a fixed destination name per plugin avoids accumulating multiple versions.
    local _destination="${_plugins_dir}/${_plugin_name}.jar"

    # Now we install the plugin that is shipped with Neo4j.
    # ${_location} is deliberately unquoted so glob patterns expand to real file names.
    for filename in ${_location}; do
        echo "Installing Plugin '${_plugin_name}' from ${_location} to ${_destination}"
        # BUG FIX: copy the glob-matched file itself; the original invoked a nonexistent
        # "unknown" command in a substitution instead of using the loop variable.
        cp --preserve "${filename}" "${_destination}"
        chmod +rw "${_destination}"
    done

    # Fail hard if the installed jar is not readable by the neo4j process user.
    if ! is_readable "${_destination}"; then
        echo >&2 "Plugin at '${_destination}' is not readable"
        exit 1
    fi
}
function load_plugin_from_url
{
# Load a plugin at runtime. The provided github repository must have a versions.json on the master branch with the
# correct format.
# Returns 1 (without exiting) when the download fails, so the caller can let Neo4j start without the plugin.
local _plugin_name="${1}" #e.g. apoc, graph-algorithms, graph-ql
# Prefer a mounted /plugins folder over the in-home plugins directory.
local _plugins_dir="${NEO4J_HOME}/plugins"
if [ -d /plugins ]; then
local _plugins_dir="/plugins"
fi
# The bundled manifest (/startup/neo4j-plugins.json) maps each known plugin name to the URL of its versions.json.
local _versions_json_url="$(jq --raw-output "with_entries( select(.key==\"${_plugin_name}\") ) | to_entries[] | .value.versions" /startup/neo4j-plugins.json )"
debug_msg "Will read ${_plugin_name} versions.json from ${_versions_json_url}"
# Using the same name for the plugin irrespective of version ensures we don't end up with different versions of the same plugin
local _destination="${_plugins_dir}/${_plugin_name}.jar"
local _neo4j_version="$(neo4j --version | cut -d' ' -f2)"
# Now we call out to github to get the versions.json for this plugin and we parse that to find the url for the correct plugin jar for our neo4j version
echo "Fetching versions.json for Plugin '${_plugin_name}' from ${_versions_json_url}"
local _versions_json
if ! _versions_json="$(wget -q --timeout 300 --tries 30 -O - "${_versions_json_url}")"; then
debug_msg "ERROR: could not fetch '${_versions_json}'"
echo >&2 "ERROR: could not query ${_versions_json_url} for plugin compatibility information.
This could indicate a problem with your network or this container's network settings.
Neo4j will continue to start, but \"${_plugin_name}\" will not be loaded."
return 1
fi
# Select the jar whose "neo4j" version-range matches this server version; semver.jq is imported from /startup (-L/startup).
local _plugin_jar_url="$(echo "${_versions_json}" | jq -L/startup --raw-output "import \"semver\" as lib; [ .[] | select(.neo4j|lib::semver(\"${_neo4j_version}\")) ] | min_by(.neo4j) | .jar")"
if [[ -z "${_plugin_jar_url}" ]] || [[ "${_plugin_jar_url}" == "null" ]]; then
debug_msg "ERROR: '${_versions_json_url}' does not contain an entry for ${_neo4j_version}"
echo >&2 "ERROR: No compatible \"${_plugin_name}\" plugin found for Neo4j ${_neo4j_version} ${NEO4J_EDITION}.
This can happen with the newest Neo4j versions when a compatible plugin has not yet been released.
You can either use an older version of Neo4j, or continue without ${_plugin_name}.
Neo4j will continue to start, but \"${_plugin_name}\" will not be loaded."
else
echo "Installing Plugin '${_plugin_name}' from ${_plugin_jar_url} to ${_destination} "
# NOTE(review): this wget's exit status is not checked; a missing/unreadable result is caught
# by is_readable below, but a partially-downloaded jar would not be — TODO confirm acceptable.
wget -q --timeout 300 --tries 30 --output-document="${_destination}" "${_plugin_jar_url}"
if ! is_readable "${_destination}"; then
echo >&2 "Plugin at '${_destination}' is not readable"
exit 1
fi
fi
}
function apply_plugin_default_configuration
{
# Apply the default configuration properties a plugin declares in /startup/neo4j-plugins.json
# to neo4j.conf, without overriding anything the user set explicitly.
#   ${1} - plugin name (key into the manifest)
#   ${2} - reference conf: a snapshot of the config BEFORE plugin installation, used to
#          decide whether a property was set by the user (user-set values are never overridden)
local _plugin_name="${1}" #e.g. apoc, graph-algorithms, graphql
local _reference_conf="${2}" # used to determine if we can override properties
local _neo4j_conf="${NEO4J_HOME}/conf/neo4j.conf"
local _property _value
echo "Applying default values for plugin ${_plugin_name} to neo4j.conf"
# Iterate over the plugin's "properties" object as {key,value} entries.
for _entry in $(jq --compact-output --raw-output "with_entries( select(.key==\"${_plugin_name}\") ) | to_entries[] | .value.properties | to_entries[]" /startup/neo4j-plugins.json); do
_property="$(jq --raw-output '.key' <<< "${_entry}")"
_value="$(jq --raw-output '.value' <<< "${_entry}")"
debug_msg "${_plugin_name} requires setting ${_property}=${_value}"
# the first grep strips out comments
if grep -o "^[^#]*" "${_reference_conf}" | grep -q --fixed-strings "${_property}=" ; then
# property is already set in the user provided config. In this case we don't override what has been set explicitly by the user.
echo "Skipping ${_property} for plugin ${_plugin_name} because it is already set."
echo "You may need to add ${_value} to the ${_property} setting in your configuration file."
else
if grep -o "^[^#]*" "${_neo4j_conf}" | grep -q --fixed-strings "${_property}=" ; then
# Property exists in the current conf (e.g. added by an earlier plugin): prepend our value to its list.
# NOTE(review): ${_property}/${_value} are interpolated into the sed script unescaped; values containing
# sed metacharacters (e.g. '/', '&') would corrupt the expression — TODO confirm manifest values never do.
sed --in-place "s/${_property}=/&${_value},/" "${_neo4j_conf}"
debug_msg "${_property} was already in the configuration file, so ${_value} was added to it."
else
echo -e "\n${_property}=${_value}" >> "${_neo4j_conf}"
debug_msg "${_property}=${_value} has been added to the configuration file."
fi
fi
done
}
function install_neo4j_plugins
{
# Install every plugin requested via NEO4J_PLUGINS (a JSON array of names), either from a
# location baked into the image or by downloading it, then apply each plugin's default config.
# first verify that the requested plugins are valid.
debug_msg "One or more NEO4J_PLUGINS have been requested."
local _known_plugins=($(jq --raw-output "keys[]" /startup/neo4j-plugins.json))
debug_msg "Checking requested plugins are known and can be installed."
for plugin_name in $(echo "${NEO4J_PLUGINS}" | jq --raw-output '.[]'); do
if ! containsElement "${plugin_name}" "${_known_plugins[@]}"; then
printf >&2 "\"%s\" is not a known Neo4j plugin. Options are:\n%s" "${plugin_name}" "$(jq --raw-output "keys[1:][]" /startup/neo4j-plugins.json)"
exit 1
fi
done
# We store a copy of the config before we modify it for the plugins to allow us to see if there are user-set values in the input config that we shouldn't override
local _old_config="$(mktemp)"
if [ -e "${NEO4J_HOME}"/conf/neo4j.conf ]; then
cp "${NEO4J_HOME}"/conf/neo4j.conf "${_old_config}"
else
touch "${NEO4J_HOME}"/conf/neo4j.conf
touch "${_old_config}"
fi
for plugin_name in $(echo "${NEO4J_PLUGINS}" | jq --raw-output '.[]'); do
debug_msg "Plugin ${plugin_name} has been requested"
local _location="$(jq --raw-output "with_entries( select(.key==\"${plugin_name}\") ) | to_entries[] | .value.location" /startup/neo4j-plugins.json )"
# The nullglob subshell expands ${_location} as a glob: an empty expansion means the
# advertised file is not actually present in the image, so fall back to downloading.
if [ "${_location}" != "null" -a -n "$(shopt -s nullglob; echo ${_location})" ]; then
debug_msg "$plugin_name is already in the container at ${_location}"
load_plugin_from_location "${plugin_name}" "${_location}"
debug_msg "Applying plugin specific configurations."
apply_plugin_default_configuration "${plugin_name}" "${_old_config}"
else
debug_msg "$plugin_name must be downloaded."
# load_plugin_from_url returns non-zero on download failure; in that case we skip
# the plugin's config so Neo4j can still start without it.
if load_plugin_from_url "${plugin_name}"; then
debug_msg "Applying plugin specific configurations."
apply_plugin_default_configuration "${plugin_name}" "${_old_config}"
fi
fi
done
rm "${_old_config}"
}
function add_docker_default_to_conf
{
    # Append a docker-specific default setting to neo4j.conf, but ONLY when the setting
    # is not already present — docker defaults must never overwrite user-provided values.
    #   ${1} - setting key, e.g. "server.default_listen_address"
    #   ${2} - default value
    local _setting="${1}"
    local _value="${2}"
    if [ ! -e "${NEO4J_HOME}"/conf/neo4j.conf ] || ! grep -q "^${_setting}=" "${NEO4J_HOME}"/conf/neo4j.conf
    then
        debug_msg "Appended ${_setting}=${_value} to ${NEO4J_HOME}/conf/neo4j.conf"
        # BUG FIX: keep the whole "setting=value" inside the quotes so values containing
        # spaces or glob characters are not word-split or pathname-expanded by the shell.
        echo -e "\n${_setting}=${_value}" >> "${NEO4J_HOME}"/conf/neo4j.conf
    fi
}
function add_env_setting_to_conf
{
# Write a setting derived from an environment variable into the appropriate conf file.
# settings from environment variables should overwrite values already in the conf
#   ${1} - setting key (dotted form, e.g. "server.memory.heap.max_size")
#   ${2} - value
local _setting=${1}
local _value=${2}
local _conf_file
# Settings listed here are appended alongside existing entries instead of replacing them.
local _append_not_replace_configs=("server.jvm.additional")
# different settings need to go in different files now.
# Route by the first dotted component: apoc.* settings live in apoc.conf, everything else in neo4j.conf.
case "$(echo ${_setting} | cut -d . -f 1)" in
apoc)
_conf_file="${NEO4J_HOME}"/conf/apoc.conf
;;
*)
_conf_file="${NEO4J_HOME}"/conf/neo4j.conf
;;
esac
if [ -e "${_conf_file}" ] && grep -q -F "${_setting}=" "${_conf_file}"; then
if containsElement "${_setting}" "${_append_not_replace_configs[@]}"; then
debug_msg "${_setting} will be appended to ${_conf_file} without replacing existing settings."
else
# Remove any lines containing the setting already
# NOTE(review): ${_setting} is used unescaped in the sed address, so '.' matches any
# character — settings that are prefixes of others could over-match. TODO confirm harmless.
debug_msg "Removing existing setting for ${_setting} in ${_conf_file}"
sed --in-place "/^${_setting}=.*/d" "${_conf_file}"
fi
fi
# Then always append setting to file
debug_msg "Appended ${_setting}=${_value} to ${_conf_file}"
echo "${_setting}=${_value}" >> "${_conf_file}"
}
function set_initial_password
{
# Parse NEO4J_AUTH and set the initial database password (or disable auth entirely).
#   ${1} - auth spec: "none", or "user/password" optionally suffixed with "/true"
#          to force a password change on first login. Only the "neo4j" user is allowed.
local _neo4j_auth="${1}"
# set the neo4j initial password only if you run the database server
if [ "${cmd}" == "neo4j" ]; then
if [ "${_neo4j_auth:-}" == "none" ]; then
debug_msg "Authentication is requested to be unset"
add_env_setting_to_conf "dbms.security.auth_enabled" "false"
# Regex captures: [1]=user, [2]=password, [3]=optional case-insensitive "true" suffix.
elif [[ "${_neo4j_auth:-}" =~ ^([^/]+)\/([^/]+)/?([tT][rR][uU][eE])?$ ]]; then
admin_user="${BASH_REMATCH[1]}"
password="${BASH_REMATCH[2]}"
do_reset="${BASH_REMATCH[3]}"
if [ "${password}" == "neo4j" ]; then
echo >&2 "Invalid value for password. It cannot be 'neo4j', which is the default."
exit 1
fi
if [ "${admin_user}" != "neo4j" ]; then
echo >&2 "Invalid admin username, it must be neo4j."
exit 1
fi
# this line has an inbuilt assumption that any configuration settings from the environment have already been applied to neo4j.conf
local _min_password_length=$(cat "${NEO4J_HOME}"/conf/neo4j.conf | grep dbms.security.auth_minimum_password_length | sed -E 's/.*=(.*)/\1/')
# A too-short password only warns here (no exit) — per the message below, the server
# itself may still fail to start later if the policy is violated.
if [ "${#password}" -lt "${_min_password_length:-"8"}" ]; then
echo >&2 "Invalid value for password. The minimum password length is 8 characters.
If Neo4j fails to start, you can:
  1) Use a stronger password.
  2) Set configuration dbms.security.auth_minimum_password_length to override the minimum password length requirement.
  3) Set environment variable NEO4J_dbms_security_auth__minimum__password__length to override the minimum password length requirement."
fi
if running_as_root; then
# running set-initial-password as root will create subfolders to /data as root, causing startup fail when neo4j can't read or write the /data/dbms folder
# creating the folder first will avoid that
mkdir -p /data/dbms
debug_msg "Making sure /data/dbms is owned by ${userid}:${groupid}"
chown "${userid}":"${groupid}" /data/dbms
fi
local extra_args=()
if [ "${do_reset}" == "true" ]; then
extra_args+=("--require-password-change")
fi
if [ "${EXTENDED_CONF+"yes"}" == "yes" ]; then
extra_args+=("--expand-commands")
fi
if debugging_enabled; then
extra_args+=("--verbose")
fi
debug_msg "Setting initial password"
debug_msg "${neo4j_admin_cmd} dbms set-initial-password ***** ${extra_args[*]}"
${neo4j_admin_cmd} dbms set-initial-password "${password}" "${extra_args[@]}"
elif [ -n "${_neo4j_auth:-}" ]; then
echo "$_neo4j_auth is invalid"
echo >&2 "Invalid value for NEO4J_AUTH: '${_neo4j_auth}'"
exit 1
fi
fi
}
# ==== CODE STARTS ====
debug_msg "DEBUGGING ENABLED"
# If we're running as root, then run as the neo4j user. Otherwise
# docker is running with --user and we simply use that user. Note
# that su-exec, despite its name, does not replicate the functionality
# of exec, so we need to use both
# When root: userid/groupid are the *name* "neo4j"; otherwise they are numeric ids from id(1).
if running_as_root; then
userid="neo4j"
groupid="neo4j"
groups=($(id -G neo4j))
exec_cmd="exec su-exec neo4j:neo4j"
neo4j_admin_cmd="su-exec neo4j:neo4j neo4j-admin"
debug_msg "Running as root user inside neo4j image"
else
userid="$(id -u)"
groupid="$(id -g)"
groups=($(id -G))
exec_cmd="exec"
neo4j_admin_cmd="neo4j-admin"
debug_msg "Running as user ${userid}:${groupid} inside neo4j image"
fi
# These identity variables are read throughout the script (is_readable, chown calls, etc.).
readonly userid
readonly groupid
readonly groups
readonly exec_cmd
readonly neo4j_admin_cmd
# Need to chown the home directory
# (and lock down permissions: 700 on NEO4J_HOME and its top-level dirs, 600 on conf files)
if running_as_root; then
debug_msg "chowning ${NEO4J_HOME} recursively to ${userid}":"${groupid}"
chown -R "${userid}":"${groupid}" "${NEO4J_HOME}"
chmod 700 "${NEO4J_HOME}"
find "${NEO4J_HOME}" -mindepth 1 -maxdepth 1 -type d -exec chmod -R 700 {} \;
debug_msg "Setting all files in ${NEO4J_HOME}/conf to permissions 600"
find "${NEO4J_HOME}"/conf -type f -exec chmod -R 600 {} \;
fi
## == EXTRACT SECRETS FROM FILES ===
# These environment variables are set by using docker secrets and they override their equivalent env vars
# They are suffixed with _FILE and prefixed by the name of the env var they should override
# e.g. NEO4J_AUTH_FILE will override the value of the NEO4J_AUTH
# It's best to do this first so that the secrets are available for the rest of the script
for variable_name in $(printenv | awk -F= '{print $1}'); do
# Check if the variable ends with "_FILE" and starts with "NEO4J_"
if [[ $variable_name == *"_FILE" &&
$variable_name == "NEO4J_"* ]]; then
# Create a new variable name by removing the "_FILE" suffix
base_variable_name=${variable_name%_FILE}
# Get the value of the _FILE variable via indirect expansion: it holds the *path* to the secret file
secret_file_path="${!variable_name}"
if is_readable "${secret_file_path}"; then
# Read the secret value from the file
secret_value=$(<"$secret_file_path")
else
# File not readable
echo >&2 "The secret file '$secret_file_path' does not exist or is not readable. Make sure you have correctly configured docker secrets."
exit 1
fi
# Assign the value to the new variable
export "$base_variable_name"="$secret_value"
fi
done
# ==== CHECK LICENSE AGREEMENT ====
# Only prompt for license agreement if command contains "neo4j" in it
# Enterprise edition refuses to start unless NEO4J_ACCEPT_LICENSE_AGREEMENT is "yes" or "eval".
if [[ "${cmd}" == *"neo4j"* ]]; then
if [ "${NEO4J_EDITION}" == "enterprise" ]; then
: ${NEO4J_ACCEPT_LICENSE_AGREEMENT:="not accepted"}
if [[ "$NEO4J_ACCEPT_LICENSE_AGREEMENT" != "yes" && "$NEO4J_ACCEPT_LICENSE_AGREEMENT" != "eval" ]]; then
echo >&2 "
In order to use Neo4j Enterprise Edition you must accept the license agreement.
The license agreement is available at https://neo4j.com/terms/licensing/
If you have a support contract the following terms apply https://neo4j.com/terms/support-terms/
If you do not have a commercial license and want to evaluate the Software
please read the terms of the evaluation agreement before you accept.
https://neo4j.com/terms/enterprise_us/
(c) Neo4j Sweden AB. All Rights Reserved.
Use of this Software without a proper commercial license, or evaluation license
with Neo4j, Inc. or its affiliates is prohibited.
Neo4j has the right to terminate your usage if you are not compliant.
More information is also available at: https://neo4j.com/licensing/
If you have further inquiries about licensing, please contact us via https://neo4j.com/contact-us/
To accept the commercial license agreement set the environment variable
NEO4J_ACCEPT_LICENSE_AGREEMENT=yes
To accept the terms of the evaluation agreement set the environment variable
NEO4J_ACCEPT_LICENSE_AGREEMENT=eval
To do this you can use the following docker argument:
--env=NEO4J_ACCEPT_LICENSE_AGREEMENT=<yes|eval>
"
exit 1
fi
fi
fi
# NEO4JLABS_PLUGINS has been renamed to NEO4J_PLUGINS, but we want the old name to work for now.
# The ':' default-assignment only sets NEO4J_PLUGINS when it is unset/empty.
if [ -n "${NEO4JLABS_PLUGINS:-}" ];
then
echo >&2 "NEO4JLABS_PLUGINS has been renamed to NEO4J_PLUGINS since Neo4j 5.0.0.
The old name will still work, but is likely to be deprecated in future releases."
: ${NEO4J_PLUGINS:=${NEO4JLABS_PLUGINS:-}}
fi
# ==== CHECK FILE PERMISSIONS ON MOUNTED FOLDERS ====
# For each well-known mount point: verify permissions and, where applicable, point the
# corresponding NEO4J_server_directories_* setting at the mount (without overriding user values).
if [ -d /conf ]; then
check_mounted_folder_readable "/conf"
rm -rf "${NEO4J_HOME}"/conf/*
debug_msg "Copying contents of /conf to ${NEO4J_HOME}/conf/*"
find /conf -type f -exec cp --preserve=ownership,mode {} "${NEO4J_HOME}"/conf \;
fi
if [ -d /ssl ]; then
# Replace the default certificates directory with a symlink to the mounted /ssl folder.
check_mounted_folder_readable "/ssl"
rm -rf "${NEO4J_HOME}"/certificates
ln -s /ssl "${NEO4J_HOME}"/certificates
fi
if [ -d /plugins ]; then
if [[ -n "${NEO4J_PLUGINS:-}" ]]; then
# We need write permissions to write the required plugins to /plugins
debug_msg "Extra plugins were requested. Ensuring the mounted /plugins folder has the required write permissions."
check_mounted_folder_writable_with_chown "/plugins"
fi
check_mounted_folder_readable "/plugins"
: ${NEO4J_server_directories_plugins:="/plugins"}
fi
if [ -d /import ]; then
check_mounted_folder_readable "/import"
: ${NEO4J_server_directories_import:="/import"}
fi
if [ -d /metrics ]; then
# metrics is enterprise only
if [ "${NEO4J_EDITION}" == "enterprise" ];
then
check_mounted_folder_writable_with_chown "/metrics"
: ${NEO4J_server_directories_metrics:="/metrics"}
fi
fi
if [ -d /logs ]; then
check_mounted_folder_writable_with_chown "/logs"
: ${NEO4J_server_directories_logs:="/logs"}
fi
if [ -d /data ]; then
# /data subfolders are checked individually because they may be separate mounts.
check_mounted_folder_writable_with_chown "/data"
if [ -d /data/databases ]; then
check_mounted_folder_writable_with_chown "/data/databases"
fi
if [ -d /data/dbms ]; then
check_mounted_folder_writable_with_chown "/data/dbms"
fi
if [ -d /data/transactions ]; then
check_mounted_folder_writable_with_chown "/data/transactions"
fi
fi
if [ -d /licenses ]; then
check_mounted_folder_readable "/licenses"
: ${NEO4J_server_directories_licenses:="/licenses"}
fi
# ==== LOAD PLUGINS ====
if [[ -n "${NEO4J_PLUGINS:-}" ]]; then
# NEO4J_PLUGINS should be a json array of plugins like '["graph-algorithms", "apoc", "streams", "graphql"]'
install_neo4j_plugins
fi
# ==== RENAME LEGACY ENVIRONMENT CONF VARIABLES ====
# Env variable naming convention:
# - prefix NEO4J_
# - double underscore char '__' instead of single underscore '_' char in the setting name
# - underscore char '_' instead of dot '.' char in the setting name
# Example:
# NEO4J_server_tx__log_rotation_retention__policy env variable to set
# server.tx_log.rotation.retention_policy setting
# we only need to override the configurations with a docker specific override.
# The other config renames will be taken care of inside Neo4j.
# Each ':' default-assignment maps a pre-5.0 variable onto its new name only when the new name is unset/empty.
: ${NEO4J_db_tx__log_rotation_retention__policy:=${NEO4J_dbms_tx__log_rotation_retention__policy:-}}
: ${NEO4J_server_memory_pagecache_size:=${NEO4J_dbms_memory_pagecache_size:-}}
: ${NEO4J_server_default__listen__address:=${NEO4J_dbms_default__listen__address:-}}
if [ "${NEO4J_EDITION}" == "enterprise" ];
then
: ${NEO4J_server_discovery_advertised__address:=${NEO4J_causal__clustering_discovery__advertised__address:-}}
: ${NEO4J_server_cluster_advertised__address:=${NEO4J_causal__clustering_transaction__advertised__address:-}}
: ${NEO4J_server_cluster_raft_advertised__address:=${NEO4J_causal__clustering_raft__advertised__address:-}}
fi
# ==== SET CONFIGURATIONS ====
## == DOCKER SPECIFIC DEFAULT CONFIGURATIONS ===
## these should not override *any* configurations set by the user
debug_msg "Setting docker specific configuration overrides"
add_docker_default_to_conf "server.memory.pagecache.size" "512M"
add_docker_default_to_conf "server.default_listen_address" "0.0.0.0"
# set enterprise only docker defaults
if [ "${NEO4J_EDITION}" == "enterprise" ];
then
debug_msg "Setting docker specific Enterprise Edition overrides"
add_docker_default_to_conf "server.discovery.advertised_address" "$(hostname):5000"
add_docker_default_to_conf "server.cluster.advertised_address" "$(hostname):6000"
add_docker_default_to_conf "server.cluster.raft.advertised_address" "$(hostname):7000"
add_docker_default_to_conf "server.routing.advertised_address" "$(hostname):7688"
fi
## == ENVIRONMENT VARIABLE CONFIGURATIONS ===
## these override BOTH defaults and any existing values in the neo4j.conf file
# these are docker control envs that have the NEO4J_ prefix but we don't want to add to the config.
not_configs=("NEO4J_ACCEPT_LICENSE_AGREEMENT" "NEO4J_AUTH" "NEO4J_AUTH_PATH" "NEO4J_DEBUG" "NEO4J_EDITION" \
"NEO4J_HOME" "NEO4J_PLUGINS" "NEO4J_SHA256" "NEO4J_TARBALL" "NEO4J_DEPRECATION_WARNING")
debug_msg "Applying configuration settings that have been set using environment variables."
# list env variables with prefix NEO4J_ and create settings from them
# NOTE(review): 'sort -rn' applies a numeric reverse sort to variable *names*; presumably
# any deterministic ordering suffices here — TODO confirm intent.
for i in $( set | grep ^NEO4J_ | awk -F'=' '{print $1}' | sort -rn ); do
if containsElement "$i" "${not_configs[@]}"; then
continue
fi
# Skip env variables with suffix _FILE, these are docker secrets
if [[ "$i" == *"_FILE" ]]; then
continue
fi
# Convert the env var name to a setting key: strip NEO4J_, then '_' -> '.', then '..' (was '__') -> '_'.
setting=$(echo "${i}" | sed 's|^NEO4J_||' | sed 's|_|.|g' | sed 's|\.\.|_|g')
value=$(echo "${!i}")
# Don't allow settings with no value or settings that start with a number (neo4j converts settings to env variables and you cannot have an env variable that starts with a number)
if [[ -n ${value} ]]; then
if [[ ! "${setting}" =~ ^[0-9]+.*$ ]]; then
add_env_setting_to_conf "${setting}" "${value}"
else
echo >&2 "WARNING: ${setting} not written to conf file. Settings that start with a number are not permitted."
fi
fi
done
# ==== SET PASSWORD ====
# NEO4J_AUTH_PATH (a file containing the auth spec) takes precedence over NEO4J_AUTH.
if [[ -n "${NEO4J_AUTH_PATH:-}" ]]; then
# Validate the existence of the password file
if [ ! -f "${NEO4J_AUTH_PATH}" ]; then
echo >&2 "The password file '${NEO4J_AUTH_PATH}' does not exist"
exit 1
fi
# validate the password file is readable
check_mounted_folder_readable "${NEO4J_AUTH_PATH}"
debug_msg "Setting initial password from file ${NEO4J_AUTH_PATH}"
set_initial_password "$(cat ${NEO4J_AUTH_PATH})"
else
debug_msg "Setting initial password from environment"
set_initial_password "${NEO4J_AUTH:-}"
fi
# ==== CLEANUP RUN FILE ====
# Remove a stale pid file left over from an unclean shutdown so neo4j can start.
if [ -f "${NEO4J_HOME}"/run/neo4j.pid ];
then
rm "${NEO4J_HOME}"/run/neo4j.pid
fi
# ==== INVOKE NEO4J STARTUP ====
# Source an optional user-provided extension script before starting.
[ -f "${EXTENSION_SCRIPT:-}" ] && . ${EXTENSION_SCRIPT}
# "dump-config" copies the effective configuration out to the mounted /conf folder and exits.
if [ "${cmd}" == "dump-config" ]; then
if [ ! -d "/conf" ]; then
echo >&2 "You must mount a folder to /conf so that the configuration file(s) can be dumped to there."
exit 1
fi
check_mounted_folder_writable_with_chown "/conf"
cp --recursive "${NEO4J_HOME}"/conf/* /conf
echo "Config Dumped"
exit 0
fi
# this prints out a command for us to run.
# the command is something like: `java ...[lots of java options]... neo4j.mainClass ...[some neo4j options]...`
# putting debug messages here causes the function to break
function get_neo4j_run_cmd {
    # Ask neo4j for the exact java invocation it would execute, without running it.
    # The caller captures stdout and evals it, so nothing else may be printed here
    # (which is why there are no debug messages in this function).
    local dry_run_args=("console" "--dry-run")
    if [ "${EXTENDED_CONF+"yes"}" == "yes" ]; then
        dry_run_args+=("--expand-commands")
    fi
    if running_as_root; then
        su-exec neo4j:neo4j neo4j "${dry_run_args[@]}"
    else
        neo4j "${dry_run_args[@]}"
    fi
}
if [ "${cmd}" == "neo4j" ]; then
# separate declaration and use of get_neo4j_run_cmd so that error codes are correctly surfaced
debug_msg "getting full neo4j run command"
neo4j_console_cmd="$(get_neo4j_run_cmd)"
debug_msg "${exec_cmd} ${neo4j_console_cmd}"
#%%DEPRECATION_WARNING_PLACEHOLDER%%
# eval replaces this shell with the generated java command (exec_cmd contains "exec").
# NOTE(review): '?:' (rather than ':?') only errors when the variable is unset, not when
# empty, and the message text begins with ':' — TODO confirm intended.
eval ${exec_cmd} ${neo4j_console_cmd?:No Neo4j command was generated}
else
# Any other command is exec'd verbatim (as the neo4j user when we started as root).
debug_msg "${exec_cmd}" "$@"
${exec_cmd} "$@"
fi
```
--------------------------------------------------------------------------------
/neo4j/docker-neo4j/docker-image-src/4.4/coredb/docker-entrypoint.sh:
--------------------------------------------------------------------------------
```bash
#!/bin/bash -eu
# Neo4j 4.4 docker entrypoint: -e aborts on any command failure, -u treats unset variables as errors.
# The first positional argument is the command to run (e.g. "neo4j", or an arbitrary command).
cmd="$1"
# load useful utility functions
# (provides debug_msg, containsElement, print_permissions_advice_and_fail, etc.)
. /startup/utilities.sh
function is_readable
{
    # Tests whether the file/directory at ${1} is readable by the effective
    # user (${userid}, with supplementary groups ${groups[@]}), regardless of
    # who this script is running as. Returns 0 if readable, 1 otherwise.
    # this code is fairly ugly but works no matter who this script is running as.
    # It would be nice if the writability tests could use this logic somehow.
    local _file=${1}
    perm=$(stat -c %a "${_file}")

    # everyone permission: read bit (>=4) in the "other" octal digit
    if [[ ${perm:2:1} -ge 4 ]]; then
        return 0
    fi
    # owner permissions: only count if we are the owner, by name or by uid.
    # BUG FIX: ${_file} is now quoted in the stat calls so paths containing
    # spaces are not word-split.
    if [[ ${perm:0:1} -ge 4 ]]; then
        if [[ "$(stat -c %U "${_file}")" = "${userid}" ]] || [[ "$(stat -c %u "${_file}")" = "${userid}" ]]; then
            return 0
        fi
    fi
    # group permissions: only count if the file's group (gid or name) is one of ours
    if [[ ${perm:1:1} -ge 4 ]]; then
        if containsElement "$(stat -c %g "${_file}")" "${groups[@]}" || containsElement "$(stat -c %G "${_file}")" "${groups[@]}" ; then
            return 0
        fi
    fi
    return 1
}
function is_writable
{
    # Tests whether the file/directory at ${1} is writable by the effective
    # user. Mirrors is_readable but checks the write bit (octal digit values
    # 2, 3, 6 or 7). Returns 0 if writable, 1 otherwise.
    # It would be nice if this and the is_readable function could combine somehow
    local _file=${1}
    perm=$(stat -c %a "${_file}")

    # everyone permission
    if containsElement ${perm:2:1} 2 3 6 7; then
        return 0
    fi
    # owner permissions: only count if we are the owner, by name or by uid.
    # BUG FIX: ${_file} is now quoted in the stat calls so paths containing
    # spaces are not word-split.
    if containsElement ${perm:0:1} 2 3 6 7; then
        if [[ "$(stat -c %U "${_file}")" = "${userid}" ]] || [[ "$(stat -c %u "${_file}")" = "${userid}" ]]; then
            return 0
        fi
    fi
    # group permissions: only count if the file's group is one of ours
    if containsElement ${perm:1:1} 2 3 6 7; then
        if containsElement "$(stat -c %g "${_file}")" "${groups[@]}" || containsElement "$(stat -c %G "${_file}")" "${groups[@]}" ; then
            return 0
        fi
    fi
    return 1
}
function check_mounted_folder_readable
{
    # Abort with mount-permission guidance when the given mounted path cannot
    # be read by the user neo4j will run as.
    local _dir="${1}"
    debug_msg "checking ${_dir} is readable"
    if ! is_readable "${_dir}"
    then
        print_permissions_advice_and_fail "${_dir}" "${userid}" "${groupid}"
    fi
}
function check_mounted_folder_writable_with_chown
{
    # The /data and /log directory are a bit different because they are very likely to be mounted by the user but not
    # necessarily writable.
    # This depends on whether a user ID is passed to the container and which folders are mounted.
    #
    #   No user ID passed to container:
    #   1) No folders are mounted.
    #      The /data and /log folder are owned by neo4j by default, so should be writable already.
    #   2) Both /log and /data are mounted.
    #      This means on start up, /data and /logs are owned by an unknown user and we should chown them to neo4j for
    #      backwards compatibility.
    #
    #   User ID passed to container:
    #   1) Both /data and /logs are mounted
    #      The /data and /logs folders are owned by an unknown user but we *should* have rw permission to them.
    #      That should be verified and error (helpfully) if not.
    #   2) User mounts /data or /logs *but not both*
    #      The unmounted folder is still owned by neo4j, which should already be writable. The mounted folder should
    #      have rw permissions through user id. This should be verified.
    #   3) No folders are mounted.
    #      The /data and /log folder are owned by neo4j by default, and these are already writable by the user.
    #      (This is a very unlikely use case).
    local mountFolder=${1}
    debug_msg "checking ${mountFolder} is writable"
    if running_as_root && ! secure_mode_enabled; then
        # check folder permissions
        if ! is_writable "${mountFolder}" ; then
            # warn that we're about to chown the folder and then chown it
            echo "Warning: Folder mounted to \"${mountFolder}\" is not writable from inside container. Changing folder owner to ${userid}."
            chown -R "${userid}":"${groupid}" "${mountFolder}"
        # check permissions on files in the folder
        # BUG FIX: quoted the command substitution so the numeric comparison
        # is robust even if wc emits padded output.
        elif [ "$(su-exec "${userid}":"${groupid}" find "${mountFolder}" -not -writable | wc -l)" -gt 0 ]; then
            echo "Warning: Some files inside \"${mountFolder}\" are not writable from inside container. Changing folder owner to ${userid}."
            chown -R "${userid}":"${groupid}" "${mountFolder}"
        fi
    else
        # Not root (or secure mode): we cannot chown, so fail with advice
        # unless the folder is writable or still owned by neo4j.
        # BUG FIX: quoted ${mountFolder} in the stat call so paths containing
        # spaces are not word-split.
        if [[ ! -w "${mountFolder}" ]] && [[ "$(stat -c %U "${mountFolder}")" != "neo4j" ]]; then
            print_permissions_advice_and_fail "${mountFolder}" "${userid}" "${groupid}"
        fi
    fi
}
function load_plugin_from_location
{
    # Install a plugin from location at runtime.
    # ${1}: plugin name (used for the destination jar file name)
    # ${2}: glob/path of the plugin jar(s) already shipped inside the image
    local _plugin_name="${1}"
    local _location="${2}"

    # Prefer a user-mounted /plugins folder over NEO4J_HOME/plugins.
    local _plugins_dir="${NEO4J_HOME}/plugins"
    if [ -d /plugins ]; then
        local _plugins_dir="/plugins"
    fi

    # Using one destination name irrespective of version ensures we don't end
    # up with different versions of the same plugin.
    local _destination="${_plugins_dir}/${_plugin_name}.jar"

    # Now we install the plugin that is shipped with Neo4j.
    # ${_location} is intentionally unquoted so that glob patterns expand.
    for filename in ${_location}; do
        echo "Installing Plugin '${_plugin_name}' from ${_location} to ${_destination}"
        # BUG FIX: copy the matched file (${filename}) — previously this ran a
        # nonexistent "unknown" command and copied its (empty) output instead
        # of the plugin jar. Also quoted the chmod target.
        cp --preserve "${filename}" "${_destination}"
        chmod +rw "${_destination}"
    done

    if ! is_readable "${_destination}"; then
        echo >&2 "Plugin at '${_destination}' is not readable"
        exit 1
    fi
}
function load_plugin_from_url
{
    # Load a plugin at runtime. The provided github repository must have a versions.json on the master branch with the
    # correct format.
    # ${1}: plugin name. Returns 1 (without exiting) when compatibility info
    # cannot be fetched or no compatible release exists, so startup continues
    # without the plugin.
    local _plugin_name="${1}" #e.g. apoc, graph-algorithms, graph-ql

    # Prefer a user-mounted /plugins folder over NEO4J_HOME/plugins.
    local _plugins_dir="${NEO4J_HOME}/plugins"
    if [ -d /plugins ]; then
        local _plugins_dir="/plugins"
    fi
    local _versions_json_url="$(jq --raw-output "with_entries( select(.key==\"${_plugin_name}\") ) | to_entries[] | .value.versions" /startup/neo4j-plugins.json )"
    debug_msg "Will read ${_plugin_name} versions.json from ${_versions_json_url}"

    # Using the same name for the plugin irrespective of version ensures we don't end up with different versions of the same plugin
    local _destination="${_plugins_dir}/${_plugin_name}.jar"
    local _neo4j_version="$(neo4j --version | cut -d' ' -f2)"

    # Now we call out to github to get the versions.json for this plugin and we parse that to find the url for the correct plugin jar for our neo4j version
    echo "Fetching versions.json for Plugin '${_plugin_name}' from ${_versions_json_url}"
    local _versions_json
    if ! _versions_json="$(wget -q --timeout 300 --tries 30 -O - "${_versions_json_url}")"; then
        # BUG FIX: log the URL that failed rather than the (empty) response body.
        debug_msg "ERROR: could not fetch '${_versions_json_url}'"
        echo >&2 "ERROR: could not query ${_versions_json_url} for plugin compatibility information.
This could indicate a problem with your network or this container's network settings.
Neo4j will continue to start, but \"${_plugin_name}\" will not be loaded."
        return 1
    fi
    # Select the smallest plugin release whose semver range matches our neo4j version.
    local _plugin_jar_url="$(echo "${_versions_json}" | jq -L/startup --raw-output "import \"semver\" as lib; [ .[] | select(.neo4j|lib::semver(\"${_neo4j_version}\")) ] | min_by(.neo4j) | .jar")"
    if [[ -z "${_plugin_jar_url}" ]] || [[ "${_plugin_jar_url}" == "null" ]]; then
        debug_msg "ERROR: '${_versions_json_url}' does not contain an entry for ${_neo4j_version}"
        echo >&2 "ERROR: No compatible \"${_plugin_name}\" plugin found for Neo4j ${_neo4j_version} ${NEO4J_EDITION}.
This can happen with the newest Neo4j versions when a compatible plugin has not yet been released.
You can either use an older version of Neo4j, or continue without ${_plugin_name}.
Neo4j will continue to start, but \"${_plugin_name}\" will not be loaded."
    else
        echo "Installing Plugin '${_plugin_name}' from ${_plugin_jar_url} to ${_destination} "
        wget -q --timeout 300 --tries 30 --output-document="${_destination}" "${_plugin_jar_url}"
        if ! is_readable "${_destination}"; then
            echo >&2 "Plugin at '${_destination}' is not readable"
            exit 1
        fi
    fi
}
function apply_plugin_default_configuration
{
    # Apply the configuration defaults a plugin requires, as listed under its
    # "properties" entry in /startup/neo4j-plugins.json.
    # A property already set in the user-provided reference config is never
    # overridden; a property already present in neo4j.conf gets the plugin's
    # value appended; otherwise the property is added to neo4j.conf.
    # ${1}: plugin name
    # ${2}: snapshot of the config taken before plugin installation started
    local _plugin_name="${1}" #e.g. apoc, graph-algorithms, graphql
    local _reference_conf="${2}" # used to determine if we can override properties
    local _neo4j_conf="${NEO4J_HOME}/conf/neo4j.conf"
    local _property _value

    echo "Applying default values for plugin ${_plugin_name} to neo4j.conf"
    for _entry in $(jq --compact-output --raw-output "with_entries( select(.key==\"${_plugin_name}\") ) | to_entries[] | .value.properties | to_entries[]" /startup/neo4j-plugins.json); do
        _property="$(jq --raw-output '.key' <<< "${_entry}")"
        _value="$(jq --raw-output '.value' <<< "${_entry}")"
        debug_msg "${_plugin_name} requires setting ${_property}=${_value}"

        # the first grep strips out comments
        if grep -o "^[^#]*" "${_reference_conf}" | grep -q --fixed-strings "${_property}=" ; then
            # property is already set in the user provided config. In this case we don't override what has been set explicitly by the user.
            echo "Skipping ${_property} for plugin ${_plugin_name} because it is already set."
            echo "You may need to add ${_value} to the ${_property} setting in your configuration file."
        else
            if grep -o "^[^#]*" "${_neo4j_conf}" | grep -q --fixed-strings "${_property}=" ; then
                # Setting exists in neo4j.conf (e.g. added by another plugin):
                # prepend our value to the existing comma-separated list.
                sed --in-place "s/${_property}=/&${_value},/" "${_neo4j_conf}"
                debug_msg "${_property} was already in the configuration file, so ${_value} was added to it."
            else
                echo -e "\n${_property}=${_value}" >> "${_neo4j_conf}"
                debug_msg "${_property}=${_value} has been added to the configuration file."
            fi
        fi
    done
}
function install_neo4j_labs_plugins
{
    # Install every plugin named in the NEO4J_PLUGINS json array. Unknown
    # plugin names abort startup. Known plugins are either copied from a
    # location inside the image or downloaded, then get their default config.

    # first verify that the requested plugins are valid.
    debug_msg "One or more NEO4J_PLUGINS have been requested."
    local _known_plugins=($(jq --raw-output "keys[]" /startup/neo4j-plugins.json))
    debug_msg "Checking requested plugins are known and can be installed."
    for plugin_name in $(echo "${NEO4J_PLUGINS}" | jq --raw-output '.[]'); do
        if ! containsElement "${plugin_name}" "${_known_plugins[@]}"; then
            printf >&2 "\"%s\" is not a known Neo4j plugin. Options are:\n%s" "${plugin_name}" "$(jq --raw-output "keys[1:][]" /startup/neo4j-plugins.json)"
            exit 1
        fi
    done

    # We store a copy of the config before we modify it for the plugins to allow us to see if there are user-set values in the input config that we shouldn't override
    local _old_config="$(mktemp)"
    if [ -e "${NEO4J_HOME}"/conf/neo4j.conf ]; then
        cp "${NEO4J_HOME}"/conf/neo4j.conf "${_old_config}"
    else
        touch "${NEO4J_HOME}"/conf/neo4j.conf
        touch "${_old_config}"
    fi
    for plugin_name in $(echo "${NEO4J_PLUGINS}" | jq --raw-output '.[]'); do
        debug_msg "Plugin ${plugin_name} has been requested"
        local _location="$(jq --raw-output "with_entries( select(.key==\"${plugin_name}\") ) | to_entries[] | .value.location" /startup/neo4j-plugins.json )"
        # A non-null location that currently globs to at least one file means
        # the plugin jar is already shipped inside this image.
        if [ "${_location}" != "null" -a -n "$(shopt -s nullglob; echo ${_location})" ]; then
            debug_msg "$plugin_name is already in the container at ${_location}"
            load_plugin_from_location "${plugin_name}" "${_location}"
            debug_msg "Applying plugin specific configurations."
            apply_plugin_default_configuration "${plugin_name}" "${_old_config}"
        else
            debug_msg "$plugin_name must be downloaded."
            # Only apply default config if the download actually succeeded.
            if load_plugin_from_url "${plugin_name}"; then
                debug_msg "Applying plugin specific configurations."
                apply_plugin_default_configuration "${plugin_name}" "${_old_config}"
            fi
        fi
    done
    rm "${_old_config}"
}
function add_docker_default_to_conf
{
    # Append a docker default to neo4j.conf — but only when the setting is not
    # already present, because docker defaults should NOT overwrite values
    # already in the conf file.
    # ${1}: setting name, ${2}: value
    local _setting="${1}"
    local _value="${2}"

    if ! grep -q "^${_setting}=" "${NEO4J_HOME}"/conf/neo4j.conf
    then
        debug_msg "Appended ${_setting}=${_value} to ${NEO4J_HOME}/conf/neo4j.conf"
        # BUG FIX: the whole argument is quoted so values containing spaces or
        # glob characters are written verbatim instead of being word-split and
        # glob-expanded by the shell.
        echo -e "\n${_setting}=${_value}" >> "${NEO4J_HOME}"/conf/neo4j.conf
    fi
}
function add_env_setting_to_conf
{
    # Write a setting derived from an environment variable into neo4j.conf.
    # Settings from environment variables override values already in the conf,
    # except for the settings listed in _append_not_replace_configs (which may
    # legitimately appear multiple times and are appended instead).
    # ${1}: setting name, ${2}: value
    local _setting=${1}
    local _value=${2}
    local _append_not_replace_configs=("dbms.jvm.additional")

    if grep -q -F "${_setting}=" "${NEO4J_HOME}"/conf/neo4j.conf; then
        if containsElement "${_setting}" "${_append_not_replace_configs[@]}"; then
            debug_msg "${_setting} will be appended to neo4j.conf without replacing existing settings."
        else
            # Remove any lines containing the setting already
            # NOTE(review): ${_setting} is interpreted as a sed regex here; dots
            # in setting names match any character — confirm this is acceptable.
            debug_msg "Removing existing setting for ${_setting}"
            sed --in-place "/^${_setting}=.*/d" "${NEO4J_HOME}"/conf/neo4j.conf
        fi
    fi
    # Then always append setting to file
    debug_msg "Appended ${_setting}=${_value} to ${NEO4J_HOME}/conf/neo4j.conf"
    echo "${_setting}=${_value}" >> "${NEO4J_HOME}"/conf/neo4j.conf
}
function set_initial_password
{
    # Configure initial authentication from a NEO4J_AUTH-style string.
    # Accepted forms for ${1}:
    #   "none"                    -> disable authentication entirely
    #   "<user>/<password>[/true]" -> set initial password (user must be
    #                                 "neo4j"); trailing "true" (any case)
    #                                 forces a password change on first login
    # Any other non-empty value is an error. Only runs when the container is
    # actually starting the database server (cmd == "neo4j").
    local _neo4j_auth="${1}"

    # set the neo4j initial password only if you run the database server
    if [ "${cmd}" == "neo4j" ]; then
        if [ "${_neo4j_auth:-}" == "none" ]; then
            debug_msg "Authentication is requested to be unset"
            add_env_setting_to_conf "dbms.security.auth_enabled" "false"
        elif [[ "${_neo4j_auth:-}" =~ ^([^/]+)\/([^/]+)/?([tT][rR][uU][eE])?$ ]]; then
            admin_user="${BASH_REMATCH[1]}"
            password="${BASH_REMATCH[2]}"
            do_reset="${BASH_REMATCH[3]}"

            # "neo4j" is the server's default password and cannot be re-used.
            if [ "${password}" == "neo4j" ]; then
                echo >&2 "Invalid value for password. It cannot be 'neo4j', which is the default."
                exit 1
            fi
            if [ "${admin_user}" != "neo4j" ]; then
                echo >&2 "Invalid admin username, it must be neo4j."
                exit 1
            fi

            if running_as_root; then
                # running set-initial-password as root will create subfolders to /data as root, causing startup fail when neo4j can't read or write the /data/dbms folder
                # creating the folder first will avoid that
                mkdir -p /data/dbms
                debug_msg "Making sure /data/dbms is owned by ${userid}:${groupid}"
                chown "${userid}":"${groupid}" /data/dbms
            fi

            local extra_args=()
            if [ "${do_reset}" == "true" ]; then
                extra_args+=("--require-password-change")
            fi
            if [ "${EXTENDED_CONF+"yes"}" == "yes" ]; then
                extra_args+=("--expand-commands")
            fi
            if debugging_enabled; then
                extra_args+=("--verbose")
            fi
            debug_msg "Setting initial password"
            # The actual password is masked in the debug output.
            debug_msg "${neo4j_admin_cmd} set-initial-password ***** ${extra_args[*]}"
            if debugging_enabled; then
                # don't suppress any output or errors in debugging mode
                ${neo4j_admin_cmd} set-initial-password "${password}" "${extra_args[@]}"
            else
                # Will exit with error if users already exist (and print a message explaining that)
                # we probably don't want the message though, since it throws an error message on restarting the container.
                ${neo4j_admin_cmd} set-initial-password "${password}" "${extra_args[@]}" 2>/dev/null || true
            fi
        elif [ -n "${_neo4j_auth:-}" ]; then
            # NOTE(review): this echoes the raw NEO4J_AUTH value (which may
            # contain a password) to stdout in addition to the stderr message
            # below — consider removing it to avoid leaking credentials to logs.
            echo "$_neo4j_auth is invalid"
            echo >&2 "Invalid value for NEO4J_AUTH: '${_neo4j_auth}'"
            exit 1
        fi
    fi
}
# ==== CODE STARTS ====
debug_msg "DEBUGGING ENABLED"

# If we're running as root, then run as the neo4j user. Otherwise
# docker is running with --user and we simply use that user. Note
# that su-exec, despite its name, does not replicate the functionality
# of exec, so we need to use both
if running_as_root; then
    userid="neo4j"
    groupid="neo4j"
    groups=($(id -G neo4j))
    exec_cmd="exec su-exec neo4j:neo4j"
    neo4j_admin_cmd="su-exec neo4j:neo4j neo4j-admin"
    debug_msg "Running as root user inside neo4j image"
else
    userid="$(id -u)"
    groupid="$(id -g)"
    groups=($(id -G))
    exec_cmd="exec"
    neo4j_admin_cmd="neo4j-admin"
    debug_msg "Running as user ${userid}:${groupid} inside neo4j image"
fi
# These identity/command globals are used by the permission-check functions
# above and throughout the rest of the script; freeze them now.
readonly userid
readonly groupid
readonly groups
readonly exec_cmd
readonly neo4j_admin_cmd

# Need to chown the home directory
if running_as_root; then
    debug_msg "chowning ${NEO4J_HOME} recursively to ${userid}":"${groupid}"
    chown -R "${userid}":"${groupid}" "${NEO4J_HOME}"
    chmod 700 "${NEO4J_HOME}"
    # Lock down all top-level directories and every config file so only the
    # neo4j user can access them.
    find "${NEO4J_HOME}" -mindepth 1 -maxdepth 1 -type d -exec chmod -R 700 {} \;
    debug_msg "Setting all files in ${NEO4J_HOME}/conf to permissions 600"
    find "${NEO4J_HOME}"/conf -type f -exec chmod -R 600 {} \;
fi
## == EXTRACT SECRETS FROM FILES ===
# These environment variables are set by using docker secrets and they override their equivalent env vars
# They are suffixed with _FILE and prefixed by the name of the env var they should override
# e.g. NEO4J_AUTH_FILE will override the value of the NEO4J_AUTH
# It's best to do this first so that the secrets are available for the rest of the script
for variable_name in $(printenv | awk -F= '{print $1}'); do
    # Check if the variable ends with "_FILE" and starts with "NEO4J_"
    if [[ $variable_name == *"_FILE" &&
          $variable_name == "NEO4J_"* ]]; then
        # Create a new variable name by removing the "_FILE" suffix
        base_variable_name=${variable_name%_FILE}

        # Get the value of the _FILE variable.
        # ${!variable_name} is indirect expansion: the value of the variable
        # whose name is held in $variable_name, i.e. the secret file's path.
        secret_file_path="${!variable_name}"

        if is_readable "${secret_file_path}"; then
            # Read the secret value from the file
            secret_value=$(<"$secret_file_path")
        else
            # File not readable
            echo >&2 "The secret file '$secret_file_path' does not exist or is not readable. Make sure you have correctly configured docker secrets."
            exit 1
        fi
        # Assign the value to the new variable
        export "$base_variable_name"="$secret_value"
    fi
done
# ==== CHECK LICENSE AGREEMENT ====

# Only prompt for license agreement if command contains "neo4j" in it
if [[ "${cmd}" == *"neo4j"* ]]; then
  if [ "${NEO4J_EDITION}" == "enterprise" ]; then
    # The ":=no" expansion also assigns a default, so the variable is always
    # defined from here on.
    if [ "${NEO4J_ACCEPT_LICENSE_AGREEMENT:=no}" != "yes" ]; then
      echo >&2 "
In order to use Neo4j Enterprise Edition you must accept the license agreement.
The license agreement is available at https://neo4j.com/terms/licensing/
If you have a support contract the following terms apply https://neo4j.com/terms/support-terms/
(c) Neo4j Sweden AB. All Rights Reserved.
Use of this Software without a proper commercial license
with Neo4j, Inc. or its affiliates is prohibited.
Neo4j has the right to terminate your usage if you are not compliant.
More information is also available at: https://neo4j.com/licensing/
If you have further inquiries about licensing, please contact us via https://neo4j.com/contact-us/
To accept the commercial license agreement set the environment variable
NEO4J_ACCEPT_LICENSE_AGREEMENT=yes
To do this you can use the following docker argument:
--env=NEO4J_ACCEPT_LICENSE_AGREEMENT=yes
"
      exit 1
    fi
  fi
fi

# NEO4JLABS_PLUGINS is renamed to NEO4J_PLUGINS in 5.x, but we want the new name to work against 4.4 images too
if [ -n "${NEO4JLABS_PLUGINS:-}" ];
then
    : ${NEO4J_PLUGINS:=${NEO4JLABS_PLUGINS:-}}
fi
# ==== RENAME LEGACY ENVIRONMENT CONF VARIABLES ====

# Env variable naming convention:
# - prefix NEO4J_
# - double underscore char '__' instead of single underscore '_' char in the setting name
# - underscore char '_' instead of dot '.' char in the setting name
# Example:
# NEO4J_dbms_tx__log_rotation_retention__policy env variable to set
#       dbms.tx_log.rotation.retention_policy setting

# Backward compatibility - map old hardcoded env variables into new naming convention (if they aren't set already)
# Set some to default values if unset
: ${NEO4J_dbms_tx__log_rotation_retention__policy:=${NEO4J_dbms_txLog_rotation_retentionPolicy:-}}
: ${NEO4J_dbms_unmanaged__extension__classes:=${NEO4J_dbms_unmanagedExtensionClasses:-}}
: ${NEO4J_dbms_allow__format__migration:=${NEO4J_dbms_allowFormatMigration:-}}
: ${NEO4J_dbms_connectors_default__advertised__address:=${NEO4J_dbms_connectors_defaultAdvertisedAddress:-}}

if [ "${NEO4J_EDITION}" == "enterprise" ];
then
    : ${NEO4J_causal__clustering_expected__core__cluster__size:=${NEO4J_causalClustering_expectedCoreClusterSize:-}}
    : ${NEO4J_causal__clustering_initial__discovery__members:=${NEO4J_causalClustering_initialDiscoveryMembers:-}}
    # BUG FIX: removed a stray debug_msg claiming "/conf" was being copied —
    # it was a copy-paste from the mounted-folder section below and no copy
    # happens here, so the message was misleading.
    : ${NEO4J_causal__clustering_discovery__advertised__address:=${NEO4J_causalClustering_discoveryAdvertisedAddress:-}}
    : ${NEO4J_causal__clustering_transaction__advertised__address:=${NEO4J_causalClustering_transactionAdvertisedAddress:-}}
    : ${NEO4J_causal__clustering_raft__advertised__address:=${NEO4J_causalClustering_raftAdvertisedAddress:-}}
fi

# unset old hardcoded unsupported env variables
unset NEO4J_dbms_txLog_rotation_retentionPolicy NEO4J_UDC_SOURCE \
    NEO4J_dbms_unmanagedExtensionClasses NEO4J_dbms_allowFormatMigration \
    NEO4J_dbms_connectors_defaultAdvertisedAddress NEO4J_ha_serverId \
    NEO4J_ha_initialHosts NEO4J_causalClustering_expectedCoreClusterSize \
    NEO4J_causalClustering_initialDiscoveryMembers \
    NEO4J_causalClustering_discoveryListenAddress \
    NEO4J_causalClustering_discoveryAdvertisedAddress \
    NEO4J_causalClustering_transactionListenAddress \
    NEO4J_causalClustering_transactionAdvertisedAddress \
    NEO4J_causalClustering_raftListenAddress \
    NEO4J_causalClustering_raftAdvertisedAddress
# ==== CHECK FILE PERMISSIONS ON MOUNTED FOLDERS ====

# For each conventional mount point: verify permissions, then point the
# matching dbms.directories.* setting at it (":=" does not override a value
# the user already set).
if [ -d /conf ]; then
    check_mounted_folder_readable "/conf"
    # A mounted /conf completely replaces the config shipped in the image.
    rm -rf "${NEO4J_HOME}"/conf/*
    debug_msg "Copying contents of /conf to ${NEO4J_HOME}/conf/*"
    find /conf -type f -exec cp --preserve=ownership,mode {} "${NEO4J_HOME}"/conf \;
fi

if [ -d /ssl ]; then
    check_mounted_folder_readable "/ssl"
    # Replace the certificates folder with a symlink to the mounted /ssl.
    rm -rf "${NEO4J_HOME}"/certificates
    ln -s /ssl "${NEO4J_HOME}"/certificates
fi

if [ -d /plugins ]; then
    if [[ -n "${NEO4J_PLUGINS:-}" ]]; then
        # We need write permissions to write the required plugins to /plugins
        debug_msg "Extra plugins were requested. Ensuring the mounted /plugins folder has the required write permissions."
        check_mounted_folder_writable_with_chown "/plugins"
    fi
    check_mounted_folder_readable "/plugins"
    : ${NEO4J_dbms_directories_plugins:="/plugins"}
fi

if [ -d /import ]; then
    check_mounted_folder_readable "/import"
    : ${NEO4J_dbms_directories_import:="/import"}
fi

if [ -d /metrics ]; then
    # metrics is enterprise only
    if [ "${NEO4J_EDITION}" == "enterprise" ];
    then
        check_mounted_folder_writable_with_chown "/metrics"
        : ${NEO4J_dbms_directories_metrics:="/metrics"}
    fi
fi

if [ -d /logs ]; then
    check_mounted_folder_writable_with_chown "/logs"
    : ${NEO4J_dbms_directories_logs:="/logs"}
fi

if [ -d /data ]; then
    check_mounted_folder_writable_with_chown "/data"
    # Subfolders may have been mounted individually with different owners,
    # so each one is checked separately.
    if [ -d /data/databases ]; then
        check_mounted_folder_writable_with_chown "/data/databases"
    fi
    if [ -d /data/dbms ]; then
        check_mounted_folder_writable_with_chown "/data/dbms"
    fi
    if [ -d /data/transactions ]; then
        check_mounted_folder_writable_with_chown "/data/transactions"
    fi
fi

if [ -d /licenses ]; then
    check_mounted_folder_readable "/licenses"
    : ${NEO4J_dbms_directories_licenses:="/licenses"}
fi
# ==== SET CONFIGURATIONS ====

## == DOCKER SPECIFIC DEFAULT CONFIGURATIONS ===
## these should not override *any* configurations set by the user

debug_msg "Setting docker specific configuration overrides"
add_docker_default_to_conf "dbms.memory.pagecache.size" "512M"
add_docker_default_to_conf "dbms.default_listen_address" "0.0.0.0"

# set enterprise only docker defaults
if [ "${NEO4J_EDITION}" == "enterprise" ];
then
    debug_msg "Setting docker specific Enterprise Edition overrides"
    # Cluster members advertise themselves under the container's hostname.
    add_docker_default_to_conf "causal_clustering.discovery_advertised_address" "$(hostname):5000"
    add_docker_default_to_conf "causal_clustering.transaction_advertised_address" "$(hostname):6000"
    add_docker_default_to_conf "causal_clustering.raft_advertised_address" "$(hostname):7000"
    add_docker_default_to_conf "dbms.routing.advertised_address" "$(hostname):7688"
fi

## == ENVIRONMENT VARIABLE CONFIGURATIONS ===
## these override BOTH defaults and any existing values in the neo4j.conf file

# these are docker control envs that have the NEO4J_ prefix but we don't want to add to the config.
not_configs=("NEO4J_ACCEPT_LICENSE_AGREEMENT" "NEO4J_AUTH" "NEO4J_AUTH_PATH" "NEO4J_DEBUG" "NEO4J_EDITION" \
             "NEO4J_HOME" "NEO4J_PLUGINS" "NEO4J_SHA256" "NEO4J_TARBALL")

debug_msg "Applying configuration settings that have been set using environment variables."
# list env variables with prefix NEO4J_ and create settings from them
for i in $( set | grep ^NEO4J_ | awk -F'=' '{print $1}' | sort -rn ); do
    if containsElement "$i" "${not_configs[@]}"; then
        continue
    fi

    # Skip env variables with suffix _FILE, these are docker secrets
    if [[ "$i" == *"_FILE" ]]; then
        continue
    fi

    # Drop the NEO4J_ prefix, then '_' -> '.' and '__' -> '_'
    # (see the naming convention documented above).
    setting=$(echo "${i}" | sed 's|^NEO4J_||' | sed 's|_|.|g' | sed 's|\.\.|_|g')
    value=$(echo "${!i}")
    # Don't allow settings with no value or settings that start with a number (neo4j converts settings to env variables and you cannot have an env variable that starts with a number)
    if [[ -n ${value} ]]; then
        if [[ ! "${setting}" =~ ^[0-9]+.*$ ]]; then
            add_env_setting_to_conf "${setting}" "${value}"
        else
            echo >&2 "WARNING: ${setting} not written to conf file. Settings that start with a number are not permitted."
        fi
    fi
done
# ==== SET PASSWORD AND PLUGINS ====

# A password file (NEO4J_AUTH_PATH) takes precedence over the NEO4J_AUTH
# environment variable.
if [[ -n "${NEO4J_AUTH_PATH:-}" ]]; then
    # Validate the existence of the password file
    if [ ! -f "${NEO4J_AUTH_PATH}" ]; then
        echo >&2 "The password file '${NEO4J_AUTH_PATH}' does not exist"
        exit 1
    fi
    # validate the password file is readable
    check_mounted_folder_readable "${NEO4J_AUTH_PATH}"

    debug_msg "Setting initial password from file ${NEO4J_AUTH_PATH}"
    set_initial_password "$(cat ${NEO4J_AUTH_PATH})"
else
    debug_msg "Setting initial password from environment"
    set_initial_password "${NEO4J_AUTH:-}"
fi

if [[ ! -z "${NEO4J_PLUGINS:-}" ]]; then
    # NEO4J_PLUGINS should be a json array of plugins like '["graph-algorithms", "apoc", "streams", "graphql"]'
    install_neo4j_labs_plugins
fi

# ==== CLEANUP RUN FILE ====

# Remove a stale pid file left over from a previous container run, otherwise
# neo4j may believe it is already running.
if [ -f "${NEO4J_HOME}"/run/neo4j.pid ];
then
    rm "${NEO4J_HOME}"/run/neo4j.pid
fi

# ==== INVOKE NEO4J STARTUP ====

# Source the user-provided extension script, if one was configured.
[ -f "${EXTENSION_SCRIPT:-}" ] && . ${EXTENSION_SCRIPT}

# "dump-config" copies the effective configuration out to the mounted /conf
# folder and exits without starting the server.
if [ "${cmd}" == "dump-config" ]; then
    if [ ! -d "/conf" ]; then
        echo >&2 "You must mount a folder to /conf so that the configuration file(s) can be dumped to there."
        exit 1
    fi
    check_mounted_folder_writable_with_chown "/conf"
    cp --recursive "${NEO4J_HOME}"/conf/* /conf
    echo "Config Dumped"
    exit 0
fi
# this prints out a command for us to run.
# the command is something like: `java ...[lots of java options]... neo4j.mainClass ...[some neo4j options]...`
# putting debug messages here causes the function to break
# (this function's stdout is captured and later eval'ed, so any extra output
# would corrupt the generated command line)
function get_neo4j_run_cmd {
    local extra_args=()

    if [ "${EXTENDED_CONF+"yes"}" == "yes" ]; then
        extra_args+=("--expand-commands")
    fi
    if debugging_enabled ; then
        extra_args+=("--verbose")
    fi

    # --dry-run makes `neo4j console` print the java command it would run
    # instead of running it.
    if running_as_root; then
        su-exec neo4j:neo4j neo4j console --dry-run "${extra_args[@]}"
    else
        neo4j console --dry-run "${extra_args[@]}"
    fi
}
# Use su-exec to drop privileges to neo4j user
# Note that su-exec, despite its name, does not replicate the
# functionality of exec, so we need to use both
if [ "${cmd}" == "neo4j" ]; then
    # separate declaration and use of get_neo4j_run_cmd so that error codes are correctly surfaced
    debug_msg "getting full neo4j run command"
    neo4j_console_cmd="$(get_neo4j_run_cmd)"
    debug_msg "${exec_cmd} ${neo4j_console_cmd}"
    # eval is required: both variables hold multi-word commands that must be
    # word-split before execution.
    # NOTE(review): "${var?:msg}" only aborts when the variable is UNSET, not
    # when it is empty — ":?" was probably intended; confirm upstream.
    eval ${exec_cmd} ${neo4j_console_cmd?:No Neo4j command was generated}
else
    # Any other command is run as-is (via su-exec when we started as root).
    debug_msg "${exec_cmd}" "$@"
    ${exec_cmd} "$@"
fi
```
--------------------------------------------------------------------------------
/neo4j/docker-neo4j/src/test/java/com/neo4j/docker/utils/TemporaryFolderManagerTest.java:
--------------------------------------------------------------------------------
```java
package com.neo4j.docker.utils;
import org.apache.commons.compress.archivers.ArchiveEntry;
import org.apache.commons.compress.archivers.ArchiveInputStream;
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.containers.wait.strategy.Wait;
import org.testcontainers.shaded.org.apache.commons.io.IOUtils;
import org.testcontainers.utility.DockerImageName;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
// This is a test for a test utility. It does not actually test anything to do with the docker image.
// This is disabled unless we're actually trying to develop/fix the TemporaryFolderManager utility.
@Disabled
class TemporaryFolderManagerTest
{
@Order( 0 )
@TempDir
static Path outFolder;
@Order( 1 )
@RegisterExtension
public TemporaryFolderManager manager = new TemporaryFolderManager(outFolder);
@AfterEach
void clearAnyCleanupFlags()
{
// some tests may leave folders remaining that are flagged for cleanup, which can affect the
// tests that check that folders are flagged for cleanup. This will reset the flags after each test.
manager.toCompressAfterAll.clear();
}
// TEST AUTOGENERATES SENSIBLE FOLDER NAME FOR EACH UNIT TEST METHOD
@Test
void shouldDeriveFolderNameFromTestMethodName()
{
String expectedMethodNameFolderPrefix = this.getClass().getName() + "_shouldDeriveFolderNameFromTestMethodName";
String actualMethodFolderName = manager.methodOutputFolder.getFileName().toString();
// should generate folder with similar/the same name as the method's reference and add 4 random digits to the end
Assertions.assertTrue( actualMethodFolderName.startsWith( expectedMethodNameFolderPrefix ),
"Did not generate correct temporary folder name from unit test method");
// verify salt is added to foldername like <NAME>_1234
Assertions.assertEquals( expectedMethodNameFolderPrefix.length()+5, actualMethodFolderName.length(),
"Did not correctly add 4 random digits to the expected folder name");
String salt = actualMethodFolderName.substring( expectedMethodNameFolderPrefix.length() + 1 );
Assertions.assertDoesNotThrow( ()->Integer.parseInt( salt ),
"Folder name salt was not digits. Actual: " + actualMethodFolderName );
// folder should not exist until we call a createTempFolder* method
Assertions.assertFalse( manager.methodOutputFolder.toFile().exists(),
"Unit test method folder was created before requesting any folder creation.");
}
@ParameterizedTest
@ValueSource(ints = {4,5,6})
void parameterisedTestMethodsCreateDifferentFolders_unnamedInt(int parameter) throws IOException
{
String expectedMethodNameFolderRegex = this.getClass().getName() +
"_parameterisedTestMethodsCreateDifferentFolders_unnamedInt" +
"_\\[" + (parameter-3) + "\\]_" + parameter + "_\\d{4}";
verifyParameterisedFolderNaming(expectedMethodNameFolderRegex);
}
@ParameterizedTest(name = "name={0}")
@ValueSource(ints = {7,8,9})
void parameterisedTestMethodsCreateDifferentFolders_namedInt(int parameter) throws IOException
{
String expectedMethodNameFolderRegex = this.getClass().getName() +
"_parameterisedTestMethodsCreateDifferentFolders_namedInt_name="
+ parameter + "_\\d{4}";
verifyParameterisedFolderNaming(expectedMethodNameFolderRegex);
}
@ParameterizedTest(name = "name={0}")
@ValueSource(booleans = {true, false})
void parameterisedTestMethodsCreateDifferentFolders_namedBoolean(boolean parameter) throws IOException
{
String expectedMethodNameFolderRegex = this.getClass().getName() +
"_parameterisedTestMethodsCreateDifferentFolders" +
"_namedBoolean_name=" + parameter + "_\\d{4}";
verifyParameterisedFolderNaming(expectedMethodNameFolderRegex);
}
@ParameterizedTest( name = "bool1={0} bool2={1}" )
@CsvSource({"true,true", "true,false", "false,true", "false,false"})
void parameterisedTestMethodsCreateDifferentFolders_twoNamedBooleans(boolean parameter1, boolean parameter2) throws IOException
{
String expectedMethodNameFolderRegex = this.getClass().getName() +
"_parameterisedTestMethodsCreateDifferentFolders" +
"_twoNamedBooleans_bool1=" + parameter1 +
"_bool2=" + parameter2 + "_\\d{4}";
verifyParameterisedFolderNaming(expectedMethodNameFolderRegex);
}
private void verifyParameterisedFolderNaming(String expectedMethodNameFolderRegex) throws IOException
{
// get methodFolderName from TemporaryFolderManager
String actualMethodFolderName = manager.methodOutputFolder.getFileName().toString();
Assertions.assertTrue( Pattern.matches( expectedMethodNameFolderRegex, actualMethodFolderName ),
"Folder \"" + actualMethodFolderName +
"\" does not match expected regex \"" + expectedMethodNameFolderRegex + "\"");
// folder should not yet exist
Path expectedUnitTestMethodFolder = outFolder.resolve( manager.methodOutputFolder );
Assertions.assertFalse( expectedUnitTestMethodFolder.toFile().exists(),
"Folder "+expectedUnitTestMethodFolder + " should not have been created" );
// now create folder
manager.createFolder( "somename" );
Assertions.assertTrue( expectedUnitTestMethodFolder.toFile().exists(),
"Folder "+expectedUnitTestMethodFolder + " should have been created" );
}
@ParameterizedTest
@CsvSource({"/conf,conf", "/data,data", "/import,import", "/logs,logs", "/metrics,metrics", "/plugins,plugins",
"/run/something,run_something", "/place/with space,place_with_space", "/with space,with_space"})
void autoGeneratesSensibleFolderNameFromMountPoint(String mountPoint, String expectedFolderName)
{
    // slashes and spaces in the mount point should be flattened to underscores
    String actualFolderName = manager.getFolderNameFromMountPoint( mountPoint );
    Assertions.assertEquals( expectedFolderName, actualFolderName,
            "Did not autogenerate expected name from given mount point");
}
// TEST ACTUAL FOLDER CREATION AND MOUNTING
@Test
void shouldMountAnyFolderToContainer(@TempDir Path tempFolder) throws Exception
{
    // Create a file inside the container under the mount point...
    try(GenericContainer container = makeContainer())
    {
        manager.mountHostFolderAsVolume( container, tempFolder, "/root" );
        container.start();
        container.execInContainer( "touch", "/root/testout" );
        String lsOutput = container.execInContainer( "ls", "/root" ).getStdout();
        // sanity check that /root/testout actually was created
        Assertions.assertTrue( lsOutput.contains( "testout" ),
                "did not manage to create file inside container in the mounted folder." );
    }
    // ...then verify it is visible on the host side of the mount.
    Assertions.assertTrue( tempFolder.resolve( "testout" ).toFile().exists(),
            "Test file was created in container but not in mounted folder. " +
            "Probably it was unsuccessfully mounted" );
}
@Test
void createsFolder() throws Exception
{
    String expectedMethodNameFolderRegex = this.getClass().getName() + "_createsFolder_\\d{4}";
    String folderName = "somefolder";
    // first verify that no folder exists until we create something.
    // Files.list opens a directory handle that must be closed; the previous
    // version never closed the stream and leaked the handle.
    List<Path> allFolders;
    try ( var paths = Files.list( outFolder ) )
    {
        allFolders = paths.filter( path -> path.getFileName()
                                               .toString()
                                               .matches( expectedMethodNameFolderRegex ))
                          .toList();
    }
    Assertions.assertEquals( 0, allFolders.size(), "A folder matching " + expectedMethodNameFolderRegex +
            " was created when it should not have been");
    // now create a folder
    Path p = manager.createFolder( folderName );
    // verify folder exists, and is located at outFolder > METHODNAME > somefolder
    Path methodNameFolder = verifyMethodNameFolderExistsAndIsUnique( expectedMethodNameFolderRegex );
    verifyTempFolder( p, folderName, methodNameFolder );
}
@Test
void createsFolderUnderGivenParent() throws Exception
{
    String expectedMethodNameFolderRegex = this.getClass().getName() + "_createsFolderUnderGivenParent_\\d{4}";
    Path unusedFolder = manager.createFolder( "somefolder1" );
    Path expectedParent = manager.createFolder( "somefolder2" );
    Path p = manager.createFolder( "somefolder3", expectedParent);
    Path methodNameFolder = verifyMethodNameFolderExistsAndIsUnique( expectedMethodNameFolderRegex );
    verifyTempFolder( unusedFolder, "somefolder1", methodNameFolder );
    verifyTempFolder( expectedParent, "somefolder2", methodNameFolder );
    verifyTempFolder( p, "somefolder3", expectedParent );
    // should NOT have created something under unusedFolder.
    // Close the Files.list stream via try-with-resources; the previous
    // version leaked the directory handle.
    List<Path> f;
    try ( var children = Files.list( unusedFolder ) )
    {
        f = children.toList();
    }
    Assertions.assertEquals( 0, f.size(),
            "Folder should not have been created under "+unusedFolder );
}
@Test
void doesNotCreateFolderOutsideRoot(@TempDir Path nonRootFolder)
{
    // Any attempt to create a managed folder outside the configured root must be rejected.
    Assertions.assertThrows(
            IOException.class,
            () -> manager.createFolder( "somefolder", nonRootFolder ),
            "Created a test folder outside the expected area" );
}
@Test
void createNamedFolderAndMount() throws Exception
{
    String expectedMethodNameFolderRegex = this.getClass().getName() + "_createNamedFolderAndMount_\\d{4}";
    String expectedFolderName = "aFolder";
    Path mountedFolder;
    try(GenericContainer container = makeContainer())
    {
        mountedFolder = manager.createNamedFolderAndMountAsVolume( container, expectedFolderName, "/root" );
        container.start();
        container.execInContainer( "touch", "/root/testout" );
        String lsOutput = container.execInContainer( "ls", "/root" ).getStdout();
        // sanity check that /root/testout actually was created
        Assertions.assertTrue( lsOutput.contains( "testout" ),
                "did not manage to create file inside container in the mounted folder." );
    }
    // the named folder should sit directly under the per-method output folder
    Path methodFolder = verifyMethodNameFolderExistsAndIsUnique( expectedMethodNameFolderRegex );
    Path expectedTempFolder = methodFolder.resolve( expectedFolderName );
    verifyTempFolder( expectedTempFolder, expectedFolderName, methodFolder );
    Assertions.assertEquals( expectedTempFolder, mountedFolder,
            "Temporary folder was not created in the expected location");
    Assertions.assertTrue( expectedTempFolder.resolve( "testout" ).toFile().exists(),
            "Test file was created in container but not in mounted folder. " +
            "Probably it was unsuccessfully mounted" );
}
@Test
void createAutomaticallyNamedFolderAndMount() throws Exception
{
    String expectedMethodNameFolderRegex = this.getClass().getName() + "_createAutomaticallyNamedFolderAndMount_\\d{4}";
    // name is auto-derived from the mount point "/root"
    String expectedFolderName = "root";
    Path mountedFolder;
    try(GenericContainer container = makeContainer())
    {
        mountedFolder = manager.createFolderAndMountAsVolume( container, "/root" );
        container.start();
        container.execInContainer( "touch", "/root/testout" );
        String lsOutput = container.execInContainer( "ls", "/root" ).getStdout();
        // sanity check that /root/testout actually was created
        Assertions.assertTrue( lsOutput.contains( "testout" ),
                "did not manage to create file inside container in the mounted folder." );
    }
    Path methodFolder = verifyMethodNameFolderExistsAndIsUnique( expectedMethodNameFolderRegex );
    Path expectedTempFolder = methodFolder.resolve( expectedFolderName );
    verifyTempFolder( expectedTempFolder, expectedFolderName, methodFolder );
    Assertions.assertEquals( expectedTempFolder, mountedFolder,
            "Temporary folder was not created in the expected location");
    Assertions.assertTrue( expectedTempFolder.resolve( "testout" ).toFile().exists(),
            "Test file was created in container but not in mounted folder. " +
            "Probably it was unsuccessfully mounted" );
}
/**
 * Asserts that exactly one folder matching the given regex exists under the
 * test output root, that it is the folder the manager recorded, and that it
 * is flagged for cleanup. Returns the folder.
 */
private Path verifyMethodNameFolderExistsAndIsUnique(String expectedMethodNameFolderRegex) throws Exception
{
    // get methodFolderName from TemporaryFolderManager
    String actualMethodFolderName = manager.methodOutputFolder.getFileName().toString();
    Assertions.assertTrue( Pattern.matches( expectedMethodNameFolderRegex, actualMethodFolderName ),
            "Folder \"" + manager.methodOutputFolder +
            "\" does not match expected regex \"" + expectedMethodNameFolderRegex + "\"");
    // verify <METHODNAME> folder was created under the root folder store.
    // Files.list must be closed or the directory handle leaks; try-with-resources
    // guarantees that even if an assertion below throws.
    List<Path> methodNameFolders;
    try ( var paths = Files.list( outFolder ) )
    {
        methodNameFolders = paths.filter( path -> path.getFileName()
                                                      .toString()
                                                      .matches( expectedMethodNameFolderRegex ) )
                                 .toList();
    }
    Assertions.assertEquals( 1, methodNameFolders.size(), "Expected only one folder called " +
            expectedMethodNameFolderRegex + ". Actual: " +
            methodNameFolders.stream()
                             .map(Path::toString)
                             .collect( Collectors.joining( ",")));
    Path methodFolder = methodNameFolders.get( 0 ); // previous assertion guarantees this to work
    Assertions.assertEquals( methodFolder, manager.methodOutputFolder,
            "Folder found in TestTemp is not the same as the one in the folder manager" );
    // make sure the <METHODNAME> folder is marked for cleanup
    Assertions.assertTrue( manager.toCompressAfterAll.contains( methodFolder ),
            "Did not flag " + methodFolder.getFileName() + " for cleanup. Flagged files are: " +
            manager.toCompressAfterAll.stream()
                                      .map(Path::toString)
                                      .collect( Collectors.joining( ",")));
    return methodFolder;
}
/**
 * Asserts that the given path is an existing directory with the expected
 * name, located directly under the expected parent.
 */
private void verifyTempFolder(Path tempFolder, String expectedFolderName, Path expectedParent)
{
    Assertions.assertTrue( tempFolder.toFile().exists(), "createTempFolder did not create anything" );
    Assertions.assertTrue( tempFolder.toFile().isDirectory(), "Did not create a directory" );
    Assertions.assertEquals(expectedFolderName, tempFolder.toFile().getName(),
            "Did not give temp directory the expected name" );
    // assertEquals (rather than assertTrue on equals) reports expected vs actual on failure
    Assertions.assertEquals( expectedParent, tempFolder.getParent(),
            "Did not create temp folder under expected parent location. Actual: "+tempFolder.getParent() );
}
// TEST FOLDER IS CLEANED UP
/**
 * Asserts that the test output root contains exactly one tar.gz file whose
 * name matches the given regex, and returns it.
 */
private File verifyTarIsCreatedAndUnique(String expectedTarRegex) throws Exception
{
    // verify outFolder contains ONE tar matching our regex.
    // Files.list must be closed to avoid leaking the directory handle.
    List<Path> tarredFiles;
    try ( var paths = Files.list( outFolder ) )
    {
        tarredFiles = paths.filter( path -> path.getFileName()
                                                .toString()
                                                .matches( expectedTarRegex ) )
                           .toList();
    }
    // message previously said "folder" even though this checks for a tar file
    Assertions.assertEquals( 1, tarredFiles.size(), "Expected only one tar file called " +
            expectedTarRegex + ". Actual: " +
            tarredFiles.stream()
                       .map(Path::toString)
                       .collect( Collectors.joining( ",")));
    return tarredFiles.get( 0 ).toFile();
}
@Test
void createsTarOfFolder() throws Exception
{
    Assumptions.assumeFalse( TestSettings.SKIP_MOUNTED_FOLDER_TARBALLING, "Temporary folder zipping disabled" );
    String expectedTarRegex = this.getClass().getName() + "_createsTarOfFolder_\\d{4}\\.tar\\.gz";
    String expectedFileContents = "words words words";
    // one folder holding a single file; cleanup should tar it up and delete it
    Path zipSource = manager.createFolder( "tozip" );
    Path testFile = zipSource.resolve( "testfile" );
    Files.writeString( testFile, expectedFileContents );
    Assertions.assertTrue( testFile.toFile().exists(),
            "Test failure. Did not successfully write to "+zipSource);
    manager.triggerCleanup();
    // the tar must exist and contain exactly our file, with the right contents
    File actualTar = verifyTarIsCreatedAndUnique( expectedTarRegex );
    List<String> files = listFilesInTar( actualTar );
    Assertions.assertEquals( 1, files.size(),
            "Tar file "+actualTar+" exists but is empty." );
    String writtenFile = readFileInTar( actualTar, "/tozip/testfile" );
    Assertions.assertEquals( expectedFileContents, writtenFile );
    // the source folder must be gone after cleanup
    Assertions.assertFalse( zipSource.toFile().exists(), "Temporary folder should have been deleted" );
}
@Test
void createsTarOfFolder_2Files() throws Exception
{
    Assumptions.assumeFalse( TestSettings.SKIP_MOUNTED_FOLDER_TARBALLING, "Temporary folder zipping disabled" );
    String expectedTarRegex = this.getClass().getName() + "_createsTarOfFolder_2Files_\\d{4}\\.tar\\.gz";
    String expectedFileContents1 = "words1 words1 words1";
    String expectedFileContents2 = "words2 words2 words2";
    // create one folder with TWO files to be zipped (comment previously said one).
    Path tempFolder = manager.createFolder( "tozip" );
    Files.writeString( tempFolder.resolve( "testfile1" ), expectedFileContents1 );
    Assertions.assertTrue( tempFolder.resolve( "testfile1" ).toFile().exists(),
            "Test failure. Did not successfully write to "+tempFolder);
    Files.writeString( tempFolder.resolve( "testfile2" ), expectedFileContents2 );
    Assertions.assertTrue( tempFolder.resolve( "testfile2" ).toFile().exists(),
            "Test failure. Did not successfully write to "+tempFolder);
    manager.triggerCleanup();
    File actualTar = verifyTarIsCreatedAndUnique( expectedTarRegex );
    List<String> files = listFilesInTar( actualTar );
    // fixed assertion message: with only 1 of 2 entries the tar is not "empty",
    // it is missing expected files (wording matches the _2Folders test)
    Assertions.assertEquals( 2, files.size(),
            "Tar file "+actualTar+" exists but does not contain the expected files." );
    String writtenFile1 = readFileInTar( actualTar, "/tozip/testfile1" );
    String writtenFile2 = readFileInTar( actualTar, "/tozip/testfile2" );
    Assertions.assertEquals( expectedFileContents1, writtenFile1 );
    Assertions.assertEquals( expectedFileContents2, writtenFile2 );
    Assertions.assertFalse( tempFolder.toFile().exists(), "Temporary folder should have been deleted" );
}
@Test
void createsTarOfFolder_2Folders() throws Exception
{
    Assumptions.assumeFalse( TestSettings.SKIP_MOUNTED_FOLDER_TARBALLING, "Temporary folder zipping disabled" );
    String expectedTarRegex = this.getClass().getName() + "_createsTarOfFolder_2Folders_\\d{4}\\.tar\\.gz";
    String contentsA = "words1 words1 words1";
    String contentsB = "words2 words2 words2";
    // two sibling folders with one file each should both land in the same tar
    Path zipDirA = manager.createFolder( "tozip1" );
    Files.writeString( zipDirA.resolve( "testfile" ), contentsA );
    Assertions.assertTrue( zipDirA.resolve( "testfile" ).toFile().exists(),
            "Test failure. Did not successfully write to "+zipDirA);
    Path zipDirB = manager.createFolder( "tozip2" );
    Files.writeString( zipDirB.resolve( "testfile" ), contentsB );
    Assertions.assertTrue( zipDirB.resolve( "testfile" ).toFile().exists(),
            "Test failure. Did not successfully write to "+zipDirB);
    manager.triggerCleanup();
    File actualTar = verifyTarIsCreatedAndUnique( expectedTarRegex );
    List<String> tarEntries = listFilesInTar( actualTar );
    Assertions.assertEquals( 2, tarEntries.size(),
            "Tar file "+actualTar+" exists but does not contain the expected files." );
    Assertions.assertEquals( contentsA, readFileInTar( actualTar, "/tozip1/testfile" ) );
    Assertions.assertEquals( contentsB, readFileInTar( actualTar, "/tozip2/testfile" ) );
    Assertions.assertFalse( zipDirA.toFile().exists(), "Temporary folder should have been deleted" );
    Assertions.assertFalse( zipDirB.toFile().exists(), "Temporary folder should have been deleted" );
}
@Test
void createsTarOfFolder_nestedFolders() throws Exception
{
    Assumptions.assumeFalse( TestSettings.SKIP_MOUNTED_FOLDER_TARBALLING, "Temporary folder zipping disabled" );
    String expectedTarRegex = this.getClass().getName() + "_createsTarOfFolder_nestedFolders_\\d{4}\\.tar\\.gz";
    // layout under test:
    //   tozip1/
    //     tozip2/
    //       testfile
    String expectedFileContents = "words words words";
    Path outerFolder = manager.createFolder( "tozip1" );
    Path innerFolder = manager.createFolder( "tozip2", outerFolder );
    Files.writeString( innerFolder.resolve( "testfile" ), expectedFileContents );
    Assertions.assertTrue( innerFolder.resolve( "testfile" ).toFile().exists(),
            "Test failure. Did not successfully write to "+innerFolder);
    manager.triggerCleanup();
    File actualTar = verifyTarIsCreatedAndUnique( expectedTarRegex );
    List<String> tarEntries = listFilesInTar( actualTar );
    Assertions.assertEquals( 1, tarEntries.size(),
            "Tar file "+actualTar+" exists but is empty." );
    // the nested path must be preserved inside the tarball
    Assertions.assertEquals( expectedFileContents, readFileInTar( actualTar,"/tozip1/tozip2/testfile" ) );
    Assertions.assertFalse( outerFolder.toFile().exists(), "Temporary folder should have been deleted" );
}
// TEST CODE CLEANUP WITH REOWNING
@Test
void canSetFolderOwnerTo7474ThenCleanup() throws Exception
{
    Assumptions.assumeFalse( TestSettings.SKIP_MOUNTED_FOLDER_TARBALLING, "Temporary folder zipping disabled" );
    String expectedTarRegex = this.getClass().getName() + "_canSetFolderOwnerTo7474ThenCleanup_\\d{4}\\.tar\\.gz";
    Path reownedFolder = manager.createFolder( "tozip" );
    Files.writeString ( reownedFolder.resolve( "testfile" ), "words" );
    manager.setFolderOwnerToNeo4j( reownedFolder );
    // the folder should now be owned by the neo4j uid (7474)
    int uid = (Integer) Files.getAttribute( reownedFolder, "unix:uid" );
    Assertions.assertEquals( 7474, uid,
            "Did not successfully set the owner of "+reownedFolder );
    // cleanup must cope with files it does not own: tar up, then delete
    manager.triggerCleanup();
    verifyTarIsCreatedAndUnique( expectedTarRegex );
    Assertions.assertFalse( reownedFolder.toFile().exists(), "Did not successfully delete "+reownedFolder );
}
@Test
void canCreateAndCleanupFoldersWithDifferentOwners() throws Exception
{
    Assumptions.assumeFalse( TestSettings.SKIP_MOUNTED_FOLDER_TARBALLING, "Temporary folder zipping disabled" );
    String expectedTarRegex = this.getClass().getName() + "_canCreateAndCleanupFoldersWithDifferentOwners_\\d{4}\\.tar\\.gz";
    // one folder reowned to the neo4j uid, one left owned by the test user
    Path neo4jOwned = manager.createFolder( "tozip7474" );
    Path userOwned = manager.createFolder( "tozipNormal" );
    Files.writeString( neo4jOwned.resolve( "testfile" ), "words" );
    Files.writeString( userOwned.resolve( "testfile" ), "words" );
    manager.setFolderOwnerToNeo4j( neo4jOwned );
    int uid = (Integer) Files.getAttribute( neo4jOwned, "unix:uid" );
    Assertions.assertEquals( 7474, uid,
            "Did not successfully set the owner of "+neo4jOwned );
    // cleanup should remove both folders regardless of ownership
    manager.triggerCleanup();
    verifyTarIsCreatedAndUnique( expectedTarRegex );
    Assertions.assertFalse( userOwned.toFile().exists(), "Did not successfully delete "+userOwned );
    Assertions.assertFalse( neo4jOwned.toFile().exists(), "Did not successfully delete "+neo4jOwned );
}
/**
 * Builds a throwaway container for mount tests.
 * We don't want to test the neo4j container, just use a generic container to check mounting.
 * nginx is used because waiting on HTTP "/" gives a simple readiness signal.
 */
private GenericContainer makeContainer()
{
    return new GenericContainer(DockerImageName.parse("nginx:latest"))
            .withExposedPorts(80)
            .waitingFor(Wait.forHttp("/").withStartupTimeout( Duration.ofSeconds( 5 ) ));
}
/**
 * Returns the entry names contained in the given .tar.gz file.
 */
private List<String> listFilesInTar(File tar) throws IOException
{
    List<String> files = new ArrayList<>();
    // try-with-resources closes the whole stream chain even if reading throws;
    // the previous version leaked the FileInputStream on error.
    try ( ArchiveInputStream in = new TarArchiveInputStream(
            new GzipCompressorInputStream( new FileInputStream(tar) )) )
    {
        ArchiveEntry entry = in.getNextEntry();
        while(entry != null)
        {
            files.add( entry.getName() );
            entry = in.getNextEntry();
        }
    }
    return files;
}
/**
 * Extracts and returns the contents of one file inside the given .tar.gz.
 * Fails the test if the entry is not present.
 */
private String readFileInTar(File tar, String internalFilePath) throws IOException
{
    // entries inside the tar are prefixed with the tar's own base name
    internalFilePath = tar.getName().split( "\\.tar\\.gz" )[0] + internalFilePath;
    String fileContents = null;
    // try-with-resources closes the stream chain even when extraction fails;
    // the previous version leaked the FileInputStream on error.
    try ( ArchiveInputStream in = new TarArchiveInputStream(
            new GzipCompressorInputStream( new FileInputStream(tar) )) )
    {
        ArchiveEntry entry = in.getNextEntry();
        while(entry != null)
        {
            if(entry.getName().equals( internalFilePath ))
            {
                ByteArrayOutputStream outStream = new ByteArrayOutputStream();
                IOUtils.copy(in, outStream);
                fileContents = outStream.toString();
                break;
            }
            entry = in.getNextEntry();
        }
    }
    Assertions.assertNotNull( fileContents, "Could not extract file "+internalFilePath+" from "+tar);
    return fileContents;
}
}
```
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
```python
"""
Utility functions for the Crawl4AI MCP server.
"""
import os
import concurrent.futures
from typing import List, Dict, Any, Optional, Tuple
import json
from supabase import create_client, Client
from urllib.parse import urlparse
import openai
import re
import time
# Load OpenAI API key for embeddings
openai.api_key = os.getenv("OPENAI_API_KEY")
# Initialize OpenRouter client for chat completions
def get_openrouter_client():
    """
    Build an OpenAI SDK client pointed at the OpenRouter API.

    Reads OPENROUTER_API_KEY (required) plus the optional YOUR_SITE_URL /
    YOUR_SITE_NAME ranking headers from the environment.

    Returns:
        OpenAI client configured for OpenRouter chat completions

    Raises:
        ValueError: if OPENROUTER_API_KEY is not set
    """
    api_key = os.getenv("OPENROUTER_API_KEY")
    if not api_key:
        raise ValueError("OPENROUTER_API_KEY must be set in environment variables for chat completions")
    ranking_headers = {
        "HTTP-Referer": os.getenv("YOUR_SITE_URL", ""),  # Optional for rankings
        "X-Title": os.getenv("YOUR_SITE_NAME", "Crawl4AI-MCP"),  # Optional for rankings
    }
    return openai.OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=api_key,
        default_headers=ranking_headers,
    )
def get_supabase_client() -> Client:
    """
    Get a Supabase client with the URL and key from environment variables.

    Returns:
        Supabase client instance

    Raises:
        ValueError: if SUPABASE_URL or SUPABASE_SERVICE_KEY is missing
    """
    url = os.getenv("SUPABASE_URL")
    key = os.getenv("SUPABASE_SERVICE_KEY")
    if not (url and key):
        raise ValueError("SUPABASE_URL and SUPABASE_SERVICE_KEY must be set in environment variables")
    return create_client(url, key)
def create_embeddings_batch(texts: List[str]) -> List[List[float]]:
    """
    Create embeddings for multiple texts in a single API call.

    Retries the batch request with exponential backoff; if every attempt
    fails, falls back to embedding each text individually, substituting a
    1536-dim zero vector for any text that still cannot be embedded so the
    output stays aligned with the input list.

    Args:
        texts: List of texts to create embeddings for

    Returns:
        List of embeddings (each embedding is a list of floats)
    """
    if not texts:
        return []

    max_retries = 3
    retry_delay = 1.0  # seconds; doubled after each failed attempt

    for attempt in range(max_retries):
        try:
            response = openai.embeddings.create(
                model="text-embedding-3-small",  # Hardcoding embedding model for now, will change this later to be more dynamic
                input=texts,
            )
            return [item.embedding for item in response.data]
        except Exception as e:
            if attempt + 1 < max_retries:
                print(f"Error creating batch embeddings (attempt {attempt + 1}/{max_retries}): {e}")
                print(f"Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
                retry_delay *= 2  # Exponential backoff
            else:
                print(f"Failed to create batch embeddings after {max_retries} attempts: {e}")
                # Last resort: embed one text at a time
                print("Attempting to create embeddings individually...")
                embeddings = []
                successful_count = 0
                for i, text in enumerate(texts):
                    try:
                        single = openai.embeddings.create(
                            model="text-embedding-3-small",
                            input=[text],
                        )
                        embeddings.append(single.data[0].embedding)
                        successful_count += 1
                    except Exception as individual_error:
                        print(f"Failed to create embedding for text {i}: {individual_error}")
                        # Zero vector keeps positions aligned with the input texts
                        embeddings.append([0.0] * 1536)
                print(f"Successfully created {successful_count}/{len(texts)} embeddings individually")
                return embeddings
def create_embedding(text: str) -> List[float]:
    """
    Create an embedding for a single text using OpenAI's API.

    Delegates to create_embeddings_batch with a one-element list; any error
    results in a 1536-dim zero vector instead of an exception.

    Args:
        text: Text to create an embedding for

    Returns:
        List of floats representing the embedding
    """
    try:
        batch = create_embeddings_batch([text])
        if batch:
            return batch[0]
        return [0.0] * 1536
    except Exception as e:
        print(f"Error creating embedding: {e}")
        # Return empty embedding if there's an error
        return [0.0] * 1536
def generate_contextual_embedding(full_document: str, chunk: str) -> Tuple[str, bool]:
    """
    Generate contextual information for a chunk within a document to improve retrieval.

    Asks the chat model configured by MODEL_CHOICE (default openai/gpt-4.1-nano)
    to describe where the chunk sits in the document, then prepends that
    description to the chunk.

    Args:
        full_document: The complete document text
        chunk: The specific chunk of text to generate context for

    Returns:
        Tuple of (contextualized chunk text, True) on success, or
        (original chunk, False) when the LLM call fails.
    """
    model_choice = os.getenv("MODEL_CHOICE", "openai/gpt-4.1-nano")

    # Only the first 25k characters of the document are sent to keep the prompt bounded.
    prompt = f"""<document>
{full_document[:25000]}
</document>
Here is the chunk we want to situate within the whole document
<chunk>
{chunk}
</chunk>
Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk. Answer only with the succinct context and nothing else."""

    try:
        llm = get_openrouter_client()
        completion = llm.chat.completions.create(
            model=model_choice,
            messages=[
                {"role": "system", "content": "You are a helpful assistant that provides concise contextual information."},
                {"role": "user", "content": prompt},
            ],
            temperature=0.3,
            max_tokens=200,
        )
        context = completion.choices[0].message.content.strip()
        return f"{context}\n---\n{chunk}", True
    except Exception as e:
        print(f"Error generating contextual embedding: {e}. Using original chunk instead.")
        return chunk, False
def process_chunk_with_context(args):
    """
    Process a single chunk with contextual embedding.

    Takes a single packed argument so it can be submitted directly to
    concurrent.futures executors.

    Args:
        args: Tuple containing (url, content, full_document); the url is not
              used here, it only travels with the task.

    Returns:
        Tuple containing:
        - The contextual text that situates the chunk within the document
        - Boolean indicating if contextual embedding was performed
    """
    _url, chunk, full_document = args
    return generate_contextual_embedding(full_document, chunk)
def add_documents_to_supabase(
    client: Client,
    urls: List[str],
    chunk_numbers: List[int],
    contents: List[str],
    metadatas: List[Dict[str, Any]],
    url_to_full_document: Dict[str, str],
    batch_size: int = 20
) -> None:
    """
    Add documents to the Supabase crawled_pages table in batches.
    Deletes existing records with the same URLs before inserting to prevent duplicates.

    Args:
        client: Supabase client
        urls: List of URLs
        chunk_numbers: List of chunk numbers
        contents: List of document contents
        metadatas: List of document metadata
        url_to_full_document: Dictionary mapping URLs to their full document content
        batch_size: Size of each batch for insertion
    """
    # Get unique URLs to delete existing records
    unique_urls = list(set(urls))

    # Delete existing records for these URLs in a single operation
    try:
        if unique_urls:
            # Use the .in_() filter to delete all records with matching URLs
            client.table("crawled_pages").delete().in_("url", unique_urls).execute()
    except Exception as e:
        print(f"Batch delete failed: {e}. Trying one-by-one deletion as fallback.")
        # Fallback: delete records one by one
        for url in unique_urls:
            try:
                client.table("crawled_pages").delete().eq("url", url).execute()
            except Exception as inner_e:
                print(f"Error deleting record for URL {url}: {inner_e}")
                # Continue with the next URL even if one fails

    # Check if contextual embeddings are enabled via environment variable
    use_contextual_embeddings = os.getenv("USE_CONTEXTUAL_EMBEDDINGS", "false") == "true"
    print(f"\n\nUse contextual embeddings: {use_contextual_embeddings}\n\n")

    # Process in batches to avoid memory issues
    for i in range(0, len(contents), batch_size):
        batch_end = min(i + batch_size, len(contents))

        # Get batch slices
        batch_urls = urls[i:batch_end]
        batch_chunk_numbers = chunk_numbers[i:batch_end]
        batch_contents = contents[i:batch_end]
        batch_metadatas = metadatas[i:batch_end]

        # Apply contextual embedding to each chunk if enabled
        if use_contextual_embeddings:
            process_args = []
            for j, content in enumerate(batch_contents):
                url = batch_urls[j]
                full_document = url_to_full_document.get(url, "")
                process_args.append((url, content, full_document))

            # BUGFIX: pre-size the result list and write each result at its own
            # index. as_completed() yields futures in COMPLETION order, so the
            # previous append-based collection could pair contextual texts with
            # the wrong URL/metadata/embedding slot whenever tasks finished out
            # of order.
            contextual_contents: List[str] = [None] * len(process_args)
            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                future_to_idx = {executor.submit(process_chunk_with_context, arg): idx
                                 for idx, arg in enumerate(process_args)}
                for future in concurrent.futures.as_completed(future_to_idx):
                    idx = future_to_idx[future]
                    try:
                        result, success = future.result()
                        contextual_contents[idx] = result
                        if success:
                            batch_metadatas[idx]["contextual_embedding"] = True
                    except Exception as e:
                        print(f"Error processing chunk {idx}: {e}")
                        # Use original content as fallback
                        contextual_contents[idx] = batch_contents[idx]

            # Safety net: any slot a worker failed to fill falls back to the original chunk
            contextual_contents = [c if c is not None else batch_contents[k]
                                   for k, c in enumerate(contextual_contents)]
        else:
            # If not using contextual embeddings, use original contents
            contextual_contents = batch_contents

        # Create embeddings for the entire batch at once
        batch_embeddings = create_embeddings_batch(contextual_contents)

        batch_data = []
        for j in range(len(contextual_contents)):
            # Extract metadata fields
            chunk_size = len(contextual_contents[j])

            # Extract source_id from URL (netloc, falling back to path for scheme-less URLs)
            parsed_url = urlparse(batch_urls[j])
            source_id = parsed_url.netloc or parsed_url.path

            # Prepare data for insertion
            data = {
                "url": batch_urls[j],
                "chunk_number": batch_chunk_numbers[j],
                "content": contextual_contents[j],  # contextualized text when enabled, else the original chunk
                "metadata": {
                    "chunk_size": chunk_size,
                    **batch_metadatas[j]
                },
                "source_id": source_id,
                "embedding": batch_embeddings[j]
            }
            batch_data.append(data)

        # Insert batch into Supabase with retry logic (exponential backoff)
        max_retries = 3
        retry_delay = 1.0  # Start with 1 second delay
        for retry in range(max_retries):
            try:
                client.table("crawled_pages").insert(batch_data).execute()
                # Success - break out of retry loop
                break
            except Exception as e:
                if retry < max_retries - 1:
                    print(f"Error inserting batch into Supabase (attempt {retry + 1}/{max_retries}): {e}")
                    print(f"Retrying in {retry_delay} seconds...")
                    time.sleep(retry_delay)
                    retry_delay *= 2  # Exponential backoff
                else:
                    # Final attempt failed; fall back to row-by-row inserts
                    print(f"Failed to insert batch after {max_retries} attempts: {e}")
                    print("Attempting to insert records individually...")
                    successful_inserts = 0
                    for record in batch_data:
                        try:
                            client.table("crawled_pages").insert(record).execute()
                            successful_inserts += 1
                        except Exception as individual_error:
                            print(f"Failed to insert individual record for URL {record['url']}: {individual_error}")
                    if successful_inserts > 0:
                        print(f"Successfully inserted {successful_inserts}/{len(batch_data)} records individually")
def search_documents(
    client: Client,
    query: str,
    match_count: int = 10,
    filter_metadata: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
    """
    Search for documents in Supabase using vector similarity.

    Embeds the query text, then calls the match_crawled_pages RPC, optionally
    passing a metadata filter. Returns an empty list if the RPC fails.

    Args:
        client: Supabase client
        query: Query text
        match_count: Maximum number of results to return
        filter_metadata: Optional metadata filter

    Returns:
        List of matching documents
    """
    # Embed the query for similarity search
    query_embedding = create_embedding(query)

    try:
        rpc_params = {
            'query_embedding': query_embedding,
            'match_count': match_count,
        }
        # The filter is only attached when provided and non-empty;
        # the dictionary is passed directly, not JSON-encoded.
        if filter_metadata:
            rpc_params['filter'] = filter_metadata
        return client.rpc('match_crawled_pages', rpc_params).execute().data
    except Exception as e:
        print(f"Error searching documents: {e}")
        return []
def extract_code_blocks(markdown_content: str, min_length: int = 1000) -> List[Dict[str, Any]]:
    """
    Extract fenced code blocks from markdown content along with surrounding context.

    Args:
        markdown_content: The markdown content to extract code blocks from
        min_length: Minimum length of code blocks to extract (default: 1000 characters)

    Returns:
        List of dictionaries with keys 'code', 'language', 'context_before',
        'context_after' and 'full_context'
    """
    code_blocks = []

    # Edge case: a file wrapped entirely in backticks. Skip past the first
    # fence so it is not treated as an opening fence. BUGFIX: locate the fence
    # in the ORIGINAL string rather than assuming offset 3 -- the previous
    # hard-coded offset (computed against a stripped copy) desynced whenever
    # the content had leading whitespace.
    start_offset = 0
    if markdown_content.strip().startswith('```'):
        start_offset = markdown_content.find('```') + 3
        print("Skipping initial triple backticks")

    # Find all occurrences of triple backticks
    backtick_positions = []
    pos = start_offset
    while True:
        pos = markdown_content.find('```', pos)
        if pos == -1:
            break
        backtick_positions.append(pos)
        pos += 3

    # Process fences pairwise: each even-indexed fence opens, the next one closes
    i = 0
    while i < len(backtick_positions) - 1:
        start_pos = backtick_positions[i]
        end_pos = backtick_positions[i + 1]

        # Extract the content between backticks
        code_section = markdown_content[start_pos + 3:end_pos]

        # A language specifier is a short, space-free first line after the fence
        lines = code_section.split('\n', 1)
        if len(lines) > 1:
            first_line = lines[0].strip()
            if first_line and ' ' not in first_line and len(first_line) < 20:
                language = first_line
                code_content = lines[1].strip()
            else:
                language = ""
                code_content = code_section.strip()
        else:
            language = ""
            code_content = code_section.strip()

        # Skip if code block is too short
        if len(code_content) < min_length:
            i += 2  # Move to next pair
            continue

        # Extract up to 1000 chars of context on each side of the block
        context_start = max(0, start_pos - 1000)
        context_before = markdown_content[context_start:start_pos].strip()
        context_end = min(len(markdown_content), end_pos + 3 + 1000)
        context_after = markdown_content[end_pos + 3:context_end].strip()

        code_blocks.append({
            'code': code_content,
            'language': language,
            'context_before': context_before,
            'context_after': context_after,
            'full_context': f"{context_before}\n\n{code_content}\n\n{context_after}"
        })

        # Move to next pair (skip the closing backtick we just processed)
        i += 2

    return code_blocks
def generate_code_example_summary(code: str, context_before: str, context_after: str) -> str:
    """
    Summarize a code example with help from the documentation text around it.

    Args:
        code: The code example to summarize
        context_before: Text that appeared immediately before the code block
        context_after: Text that appeared immediately after the code block

    Returns:
        A short LLM-generated summary, or a generic fallback string when the
        chat-completion call fails.
    """
    model_choice = os.getenv("MODEL_CHOICE", "openai/gpt-4.1-nano")

    # Trim each section so the prompt stays within a reasonable size.
    # Slicing is a no-op when the text is already short enough.
    before_snippet = context_before[-500:]
    code_snippet = code[:1500]
    after_snippet = context_after[:500]

    prompt = f"""<context_before>
{before_snippet}
</context_before>
<code_example>
{code_snippet}
</code_example>
<context_after>
{after_snippet}
</context_after>
Based on the code example and its surrounding context, provide a concise summary (2-3 sentences) that describes what this code example demonstrates and its purpose. Focus on the practical application and key concepts illustrated.
"""

    try:
        # Chat completions are routed through OpenRouter.
        client = get_openrouter_client()
        completion = client.chat.completions.create(
            model=model_choice,
            messages=[
                {"role": "system", "content": "You are a helpful assistant that provides concise code example summaries."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.3,
            max_tokens=300
        )
        return completion.choices[0].message.content.strip()
    except Exception as e:
        print(f"Error generating code example summary: {e}")
        return "Code example for demonstration purposes."
def add_code_examples_to_supabase(
    client: Client,
    urls: List[str],
    chunk_numbers: List[int],
    code_examples: List[str],
    summaries: List[str],
    metadatas: List[Dict[str, Any]],
    batch_size: int = 20
):
    """
    Add code examples to the Supabase code_examples table in batches.

    Existing rows for the given URLs are deleted first, then each batch is
    embedded (code + summary together) and inserted with retry/backoff.

    Args:
        client: Supabase client
        urls: List of URLs
        chunk_numbers: List of chunk numbers
        code_examples: List of code example contents
        summaries: List of code example summaries
        metadatas: List of metadata dictionaries
        batch_size: Size of each batch for insertion
    """
    if not urls:
        return

    # Delete existing records for these URLs so re-crawls replace, not duplicate.
    unique_urls = list(set(urls))
    for url in unique_urls:
        try:
            client.table('code_examples').delete().eq('url', url).execute()
        except Exception as e:
            print(f"Error deleting existing code examples for {url}: {e}")

    total_items = len(urls)
    for i in range(0, total_items, batch_size):
        batch_end = min(i + batch_size, total_items)

        # Embed code and summary together so searches can match either.
        batch_texts = [
            f"{code_examples[j]}\n\nSummary: {summaries[j]}"
            for j in range(i, batch_end)
        ]
        embeddings = _ensure_valid_embeddings(batch_texts, create_embeddings_batch(batch_texts))

        # Build the rows for this batch.
        batch_data = []
        for j, embedding in enumerate(embeddings):
            idx = i + j
            # Extract source_id from URL (netloc for http(s), path for e.g. file URLs).
            parsed_url = urlparse(urls[idx])
            source_id = parsed_url.netloc or parsed_url.path
            batch_data.append({
                'url': urls[idx],
                'chunk_number': chunk_numbers[idx],
                'content': code_examples[idx],
                'summary': summaries[idx],
                'metadata': metadatas[idx],  # Store as JSON object, not string
                'source_id': source_id,
                'embedding': embedding
            })

        _insert_code_example_batch(client, batch_data)
        print(f"Inserted batch {i//batch_size + 1} of {(total_items + batch_size - 1)//batch_size} code examples")


def _ensure_valid_embeddings(batch_texts: List[str], embeddings: List[List[float]]) -> List[List[float]]:
    """Replace zero/empty embeddings from the batch call with one-off re-embeddings.

    The embedding at position idx corresponds to batch_texts[idx], so the
    fallback re-embeds exactly the text whose batch embedding was invalid.
    """
    valid_embeddings = []
    for idx, embedding in enumerate(embeddings):
        if embedding and not all(v == 0.0 for v in embedding):
            valid_embeddings.append(embedding)
        else:
            print("Warning: Zero or invalid embedding detected, creating new one...")
            # Try to create a single embedding as fallback
            valid_embeddings.append(create_embedding(batch_texts[idx]))
    return valid_embeddings


def _insert_code_example_batch(client: Client, batch_data: List[Dict[str, Any]], max_retries: int = 3) -> None:
    """Insert one batch with exponential-backoff retries.

    After the final failed attempt, falls back to inserting records one by one
    so a single bad row cannot sink the whole batch.
    """
    retry_delay = 1.0  # Start with 1 second delay; doubled after each failure
    for retry in range(max_retries):
        try:
            client.table('code_examples').insert(batch_data).execute()
            # Success - break out of retry loop
            break
        except Exception as e:
            if retry < max_retries - 1:
                print(f"Error inserting batch into Supabase (attempt {retry + 1}/{max_retries}): {e}")
                print(f"Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
                retry_delay *= 2  # Exponential backoff
            else:
                # Final attempt failed - try inserting records one at a time.
                print(f"Failed to insert batch after {max_retries} attempts: {e}")
                print("Attempting to insert records individually...")
                successful_inserts = 0
                for record in batch_data:
                    try:
                        client.table('code_examples').insert(record).execute()
                        successful_inserts += 1
                    except Exception as individual_error:
                        print(f"Failed to insert individual record for URL {record['url']}: {individual_error}")
                if successful_inserts > 0:
                    print(f"Successfully inserted {successful_inserts}/{len(batch_data)} records individually")
def update_source_info(client: Client, source_id: str, summary: str, word_count: int):
    """
    Upsert a row in the sources table.

    An UPDATE is attempted first; when it matches no existing row, the source
    is inserted instead. Errors are logged and swallowed so callers keep going.

    Args:
        client: Supabase client
        source_id: The source ID (domain)
        summary: Summary of the source
        word_count: Total word count for the source
    """
    try:
        update_payload = {
            'summary': summary,
            'total_word_count': word_count,
            'updated_at': 'now()'
        }
        result = client.table('sources').update(update_payload).eq('source_id', source_id).execute()

        if result.data:
            print(f"Updated source: {source_id}")
        else:
            # The UPDATE matched nothing, so this is a brand-new source.
            client.table('sources').insert({
                'source_id': source_id,
                'summary': summary,
                'total_word_count': word_count
            }).execute()
            print(f"Created new source: {source_id}")
    except Exception as e:
        print(f"Error updating source {source_id}: {e}")
def extract_source_summary(source_id: str, content: str, max_length: int = 500) -> str:
    """
    Generate a short summary of a source's content via the OpenRouter API.

    Args:
        source_id: The source ID (domain)
        content: The content to extract a summary from
        max_length: Maximum length of the summary

    Returns:
        A summary string; falls back to a generic description when the content
        is empty or the LLM call fails.
    """
    # Default summary if we can't extract anything meaningful
    default_summary = f"Content from {source_id}"
    if not content or not content.strip():
        return default_summary

    model_choice = os.getenv("MODEL_CHOICE", "openai/gpt-4.1-nano")

    # Cap content length to stay within token limits (no-op for short content).
    truncated_content = content[:25000]

    prompt = f"""<source_content>
{truncated_content}
</source_content>
The above content is from the documentation for '{source_id}'. Please provide a concise summary (3-5 sentences) that describes what this library/tool/framework is about. The summary should help understand what the library/tool/framework accomplishes and the purpose.
"""

    try:
        # Chat completions are routed through OpenRouter.
        response = get_openrouter_client().chat.completions.create(
            model=model_choice,
            messages=[
                {"role": "system", "content": "You are a helpful assistant that provides concise library/tool/framework summaries."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.3,
            max_tokens=300
        )
        summary = response.choices[0].message.content.strip()

        # Hard-truncate overly long model output.
        if len(summary) > max_length:
            return summary[:max_length] + "..."
        return summary
    except Exception as e:
        print(f"Error generating summary with LLM for {source_id}: {e}. Using default summary.")
        return default_summary
def search_code_examples(
    client: Client,
    query: str,
    match_count: int = 10,
    filter_metadata: Optional[Dict[str, Any]] = None,
    source_id: Optional[str] = None
) -> List[Dict[str, Any]]:
    """
    Search for code examples in Supabase using vector similarity.

    Args:
        client: Supabase client
        query: Query text
        match_count: Maximum number of results to return
        filter_metadata: Optional metadata filter
        source_id: Optional source ID to filter results

    Returns:
        List of matching code examples (empty list on error).
    """
    # Code rows were embedded together with their summaries, so phrase the
    # query the same way to get a better vector match.
    enhanced_query = f"Code example for {query}\n\nSummary: Example code showing {query}"
    query_embedding = create_embedding(enhanced_query)

    try:
        # Build RPC params for the match_code_examples function; optional
        # filters are only attached when actually provided and non-empty.
        params: Dict[str, Any] = {
            'query_embedding': query_embedding,
            'match_count': match_count
        }
        if filter_metadata:
            params['filter'] = filter_metadata
        if source_id:
            params['source_filter'] = source_id

        return client.rpc('match_code_examples', params).execute().data
    except Exception as e:
        print(f"Error searching code examples: {e}")
        return []
```
--------------------------------------------------------------------------------
/knowledge_graphs/parse_repo_into_neo4j.py:
--------------------------------------------------------------------------------
```python
"""
Direct Neo4j GitHub Code Repository Extractor
Creates nodes and relationships directly in Neo4j without Graphiti:
- File nodes
- Class nodes
- Method nodes
- Function nodes
- Import relationships
Bypasses all LLM processing for maximum speed.
"""
import asyncio
import logging
import os
import subprocess
import shutil
from datetime import datetime, timezone
from pathlib import Path
from typing import List, Optional, Dict, Any, Set
import ast
from dotenv import load_dotenv
from neo4j import AsyncGraphDatabase
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger(__name__)
class Neo4jCodeAnalyzer:
    """Parses Python files with `ast` and flattens their structure (classes,
    methods, attributes, top-level functions, internal imports) into plain
    dicts ready for direct Neo4j insertion."""

    def __init__(self):
        # Modules treated as external to the analyzed project; imports of these
        # are never recorded as internal project imports.
        self.external_modules = {
            # Python standard library
            'os', 'sys', 'json', 'logging', 'datetime', 'pathlib', 'typing', 'collections',
            'asyncio', 'subprocess', 'ast', 're', 'string', 'urllib', 'http', 'email',
            'time', 'uuid', 'hashlib', 'base64', 'itertools', 'functools', 'operator',
            'contextlib', 'copy', 'pickle', 'tempfile', 'shutil', 'glob', 'fnmatch',
            'io', 'codecs', 'locale', 'platform', 'socket', 'ssl', 'threading', 'queue',
            'multiprocessing', 'concurrent', 'warnings', 'traceback', 'inspect',
            'importlib', 'pkgutil', 'types', 'weakref', 'gc', 'dataclasses', 'enum',
            'abc', 'numbers', 'decimal', 'fractions', 'math', 'cmath', 'random', 'statistics',
            # Common third-party libraries
            'requests', 'urllib3', 'httpx', 'aiohttp', 'flask', 'django', 'fastapi',
            'pydantic', 'sqlalchemy', 'alembic', 'psycopg2', 'pymongo', 'redis',
            'celery', 'pytest', 'unittest', 'mock', 'faker', 'factory', 'hypothesis',
            'numpy', 'pandas', 'matplotlib', 'seaborn', 'scipy', 'sklearn', 'torch',
            'tensorflow', 'keras', 'opencv', 'pillow', 'boto3', 'botocore', 'azure',
            'google', 'openai', 'anthropic', 'langchain', 'transformers', 'huggingface_hub',
            'click', 'typer', 'rich', 'colorama', 'tqdm', 'python-dotenv', 'pyyaml',
            'toml', 'configargparse', 'marshmallow', 'attrs', 'dataclasses-json',
            'jsonschema', 'cerberus', 'voluptuous', 'schema', 'jinja2', 'mako',
            'cryptography', 'bcrypt', 'passlib', 'jwt', 'authlib', 'oauthlib'
        }

    @staticmethod
    def _strip_py_suffix(dotted: str) -> str:
        """Remove only a trailing '.py' from a dotted module path.

        Bug fix: str.replace('.py', '') also removes '.py' occurring
        mid-string (e.g. 'my.pylib.mod.py' -> 'mylib.mod'), corrupting
        module names for directories/files containing '.py'.
        """
        return dotted[:-3] if dotted.endswith('.py') else dotted

    @staticmethod
    def _format_params_detailed(params: List[Dict[str, Any]]) -> List[str]:
        """Render parameter dicts as 'name:type[=default]' strings,
        prefixing non-positional kinds with '[kind] '."""
        params_detailed = []
        for p in params:
            param_str = f"{p['name']}:{p['type']}"
            if p['optional'] and p['default'] is not None:
                param_str += f"={p['default']}"
            elif p['optional']:
                param_str += "=None"
            if p['kind'] != 'positional':
                param_str = f"[{p['kind']}] {param_str}"
            params_detailed.append(param_str)
        return params_detailed

    def analyze_python_file(self, file_path: Path, repo_root: Path, project_modules: Set[str]) -> Optional[Dict[str, Any]]:
        """Extract structure for direct Neo4j insertion.

        Args:
            file_path: Absolute path of the Python file to analyze
            repo_root: Repository root (file_path must be inside it)
            project_modules: Top-level module names belonging to the project

        Returns:
            Dict with module_name, file_path, classes, functions, imports and
            line_count, or None when the file cannot be read/parsed.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
            tree = ast.parse(content)
            relative_path = str(file_path.relative_to(repo_root))
            module_name = self._get_importable_module_name(file_path, repo_root, relative_path)

            classes = []
            functions = []
            imports = []

            for node in ast.walk(tree):
                if isinstance(node, ast.ClassDef):
                    # Extract class with its public methods and annotated attributes
                    methods = []
                    attributes = []
                    for item in node.body:
                        if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
                            if not item.name.startswith('_'):  # Public methods only
                                params = self._extract_function_parameters(item)
                                return_type = self._get_name(item.returns) if item.returns else 'Any'
                                methods.append({
                                    'name': item.name,
                                    'params': params,  # Full parameter objects
                                    'params_detailed': self._format_params_detailed(params),
                                    'return_type': return_type,
                                    'args': [arg.arg for arg in item.args.args if arg.arg != 'self']  # Keep for backwards compatibility
                                })
                        elif isinstance(item, ast.AnnAssign) and isinstance(item.target, ast.Name):
                            # Type-annotated class attributes (public only)
                            if not item.target.id.startswith('_'):
                                attributes.append({
                                    'name': item.target.id,
                                    'type': self._get_name(item.annotation) if item.annotation else 'Any'
                                })
                    classes.append({
                        'name': node.name,
                        'full_name': f"{module_name}.{node.name}",
                        'methods': methods,
                        'attributes': attributes
                    })
                elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                    # Only functions not contained in any class body.
                    # NOTE(review): functions nested inside other functions are
                    # not in any class body either and are also collected here —
                    # confirm this is intended.
                    if not any(node in cls_node.body for cls_node in ast.walk(tree) if isinstance(cls_node, ast.ClassDef)):
                        if not node.name.startswith('_'):
                            params = self._extract_function_parameters(node)
                            return_type = self._get_name(node.returns) if node.returns else 'Any'
                            functions.append({
                                'name': node.name,
                                'full_name': f"{module_name}.{node.name}",
                                'params': params,  # Full parameter objects
                                'params_detailed': self._format_params_detailed(params),
                                'params_list': [f"{p['name']}:{p['type']}" for p in params],  # Simple format for backwards compatibility
                                'return_type': return_type,
                                'args': [arg.arg for arg in node.args.args]  # Keep for backwards compatibility
                            })
                elif isinstance(node, (ast.Import, ast.ImportFrom)):
                    # Track internal (project) imports only
                    if isinstance(node, ast.Import):
                        for alias in node.names:
                            if self._is_likely_internal(alias.name, project_modules):
                                imports.append(alias.name)
                    elif isinstance(node, ast.ImportFrom) and node.module:
                        if (node.module.startswith('.') or self._is_likely_internal(node.module, project_modules)):
                            imports.append(node.module)

            return {
                'module_name': module_name,
                'file_path': relative_path,
                'classes': classes,
                'functions': functions,
                'imports': list(set(imports)),  # Remove duplicates
                'line_count': len(content.splitlines())
            }
        except Exception as e:
            logger.warning(f"Could not analyze {file_path}: {e}")
            return None

    def _is_likely_internal(self, import_name: str, project_modules: Set[str]) -> bool:
        """Heuristically decide whether an import belongs to the analyzed project.

        Relative imports are internal; known external modules are not; anything
        matching a project module is internal; otherwise a name is assumed
        internal unless it looks like test tooling or is very short.
        """
        if not import_name:
            return False
        # Relative imports are definitely internal
        if import_name.startswith('.'):
            return True
        # Check if it's a known external module
        base_module = import_name.split('.')[0]
        if base_module in self.external_modules:
            return False
        # Check if it matches any project module
        for project_module in project_modules:
            if import_name.startswith(project_module):
                return True
        # If it's not obviously external, consider it internal
        if (not any(ext in base_module.lower() for ext in ['test', 'mock', 'fake']) and
                not base_module.startswith('_') and
                len(base_module) > 2):
            return True
        return False

    def _get_importable_module_name(self, file_path: Path, repo_root: Path, relative_path: str) -> str:
        """Determine the actual importable module name for a Python file.

        Prefers the outermost package (directory with __init__.py); otherwise
        skips common non-package wrapper directories; otherwise falls back to
        the path converted to dotted form.
        """
        # Start with the default: convert file path to module path
        default_module = self._strip_py_suffix(relative_path.replace('/', '.').replace('\\', '.'))

        path_parts = Path(relative_path).parts

        # Check each directory level for __init__.py to find package boundaries
        package_roots = []
        current_path = repo_root
        for i, part in enumerate(path_parts[:-1]):  # Exclude the .py file itself
            current_path = current_path / part
            if (current_path / '__init__.py').exists():
                # This is a package directory, mark it as a potential root
                package_roots.append(i)

        if package_roots:
            # Use the first (outermost) package as the root
            package_start = package_roots[0]
            module_parts = path_parts[package_start:]
            return self._strip_py_suffix('.'.join(module_parts))

        # Fallback: skip common non-package wrapper directories
        skip_dirs = {'src', 'lib', 'source', 'python', 'pkg', 'packages'}
        filtered_parts = []
        for part in path_parts:
            if part.lower() not in skip_dirs or filtered_parts:  # Once we start including, include everything
                filtered_parts.append(part)
        if filtered_parts:
            return self._strip_py_suffix('.'.join(filtered_parts))

        # Final fallback: use the default
        return default_module

    def _extract_function_parameters(self, func_node) -> List[Dict[str, Any]]:
        """Comprehensive parameter extraction from a function definition.

        Returns one dict per parameter with keys: name, type, kind
        ('positional'/'var_positional'/'keyword_only'/'var_keyword'),
        optional, and default ('self' is skipped).
        """
        params = []

        # Regular positional arguments
        for i, arg in enumerate(func_node.args.args):
            if arg.arg == 'self':
                continue
            param_info = {
                'name': arg.arg,
                'type': self._get_name(arg.annotation) if arg.annotation else 'Any',
                'kind': 'positional',
                'optional': False,
                'default': None
            }
            # Defaults align with the *last* len(defaults) positional args.
            defaults_start = len(func_node.args.args) - len(func_node.args.defaults)
            if i >= defaults_start:
                default_idx = i - defaults_start
                if default_idx < len(func_node.args.defaults):
                    param_info['optional'] = True
                    param_info['default'] = self._get_default_value(func_node.args.defaults[default_idx])
            params.append(param_info)

        # *args parameter
        if func_node.args.vararg:
            params.append({
                'name': f"*{func_node.args.vararg.arg}",
                'type': self._get_name(func_node.args.vararg.annotation) if func_node.args.vararg.annotation else 'Any',
                'kind': 'var_positional',
                'optional': True,
                'default': None
            })

        # Keyword-only arguments (after *)
        for i, arg in enumerate(func_node.args.kwonlyargs):
            param_info = {
                'name': arg.arg,
                'type': self._get_name(arg.annotation) if arg.annotation else 'Any',
                'kind': 'keyword_only',
                'optional': True,  # Until proven otherwise below
                'default': None
            }
            # kw_defaults has a None entry for required keyword-only args.
            if i < len(func_node.args.kw_defaults) and func_node.args.kw_defaults[i] is not None:
                param_info['default'] = self._get_default_value(func_node.args.kw_defaults[i])
            else:
                param_info['optional'] = False  # No default = required kwonly arg
            params.append(param_info)

        # **kwargs parameter
        if func_node.args.kwarg:
            params.append({
                'name': f"**{func_node.args.kwarg.arg}",
                'type': self._get_name(func_node.args.kwarg.annotation) if func_node.args.kwarg.annotation else 'Dict[str, Any]',
                'kind': 'var_keyword',
                'optional': True,
                'default': None
            })

        return params

    def _get_default_value(self, default_node) -> str:
        """Render a default-value AST node as a short display string."""
        try:
            if isinstance(default_node, ast.Constant):
                return repr(default_node.value)
            elif isinstance(default_node, ast.Name):
                return default_node.id
            elif isinstance(default_node, ast.Attribute):
                return self._get_name(default_node)
            elif isinstance(default_node, ast.List):
                return "[]"
            elif isinstance(default_node, ast.Dict):
                return "{}"
            else:
                return "..."
        except Exception:
            return "..."

    def _get_name(self, node) -> str:
        """Render an annotation/name AST node as a string, handling complex
        types (subscripts, tuples, attributes) and falling back to 'Any'."""
        if node is None:
            return "Any"
        try:
            if isinstance(node, ast.Name):
                return node.id
            elif isinstance(node, ast.Attribute):
                if hasattr(node, 'value'):
                    return f"{self._get_name(node.value)}.{node.attr}"
                else:
                    return node.attr
            elif isinstance(node, ast.Subscript):
                # Handle List[Type], Dict[K,V], etc.
                base = self._get_name(node.value)
                if hasattr(node, 'slice'):
                    if isinstance(node.slice, ast.Name):
                        return f"{base}[{node.slice.id}]"
                    elif isinstance(node.slice, ast.Tuple):
                        elts = [self._get_name(elt) for elt in node.slice.elts]
                        return f"{base}[{', '.join(elts)}]"
                    elif isinstance(node.slice, ast.Constant):
                        return f"{base}[{repr(node.slice.value)}]"
                    elif isinstance(node.slice, ast.Attribute):
                        return f"{base}[{self._get_name(node.slice)}]"
                    elif isinstance(node.slice, ast.Subscript):
                        return f"{base}[{self._get_name(node.slice)}]"
                    else:
                        # Try to get the name of the slice, fallback to Any if it fails
                        try:
                            slice_name = self._get_name(node.slice)
                            return f"{base}[{slice_name}]"
                        except:
                            return f"{base}[Any]"
                return base
            elif isinstance(node, ast.Constant):
                return str(node.value)
            elif isinstance(node, ast.Str):  # Python < 3.8
                return f'"{node.s}"'
            elif isinstance(node, ast.Tuple):
                elts = [self._get_name(elt) for elt in node.elts]
                return f"({', '.join(elts)})"
            elif isinstance(node, ast.List):
                elts = [self._get_name(elt) for elt in node.elts]
                return f"[{', '.join(elts)}]"
            else:
                # Fallback for complex types - return a simple string representation
                return "Any"
        except Exception:
            # If anything goes wrong, return a safe default
            return "Any"
class DirectNeo4jExtractor:
"""Creates nodes and relationships directly in Neo4j"""
def __init__(self, neo4j_uri: str, neo4j_user: str, neo4j_password: str):
self.neo4j_uri = neo4j_uri
self.neo4j_user = neo4j_user
self.neo4j_password = neo4j_password
self.driver = None
self.analyzer = Neo4jCodeAnalyzer()
async def initialize(self):
"""Initialize Neo4j connection"""
logger.info("Initializing Neo4j connection...")
self.driver = AsyncGraphDatabase.driver(
self.neo4j_uri,
auth=(self.neo4j_user, self.neo4j_password)
)
# Clear existing data
# logger.info("Clearing existing data...")
# async with self.driver.session() as session:
# await session.run("MATCH (n) DETACH DELETE n")
# Create constraints and indexes
logger.info("Creating constraints and indexes...")
async with self.driver.session() as session:
# Create constraints - using MERGE-friendly approach
await session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (f:File) REQUIRE f.path IS UNIQUE")
await session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (c:Class) REQUIRE c.full_name IS UNIQUE")
# Remove unique constraints for methods/attributes since they can be duplicated across classes
# await session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (m:Method) REQUIRE m.full_name IS UNIQUE")
# await session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (f:Function) REQUIRE f.full_name IS UNIQUE")
# await session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (a:Attribute) REQUIRE a.full_name IS UNIQUE")
# Create indexes for performance
await session.run("CREATE INDEX IF NOT EXISTS FOR (f:File) ON (f.name)")
await session.run("CREATE INDEX IF NOT EXISTS FOR (c:Class) ON (c.name)")
await session.run("CREATE INDEX IF NOT EXISTS FOR (m:Method) ON (m.name)")
logger.info("Neo4j initialized successfully")
async def clear_repository_data(self, repo_name: str):
"""Clear all data for a specific repository"""
logger.info(f"Clearing existing data for repository: {repo_name}")
async with self.driver.session() as session:
# Delete in specific order to avoid constraint issues
# 1. Delete methods and attributes (they depend on classes)
await session.run("""
MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)-[:HAS_METHOD]->(m:Method)
DETACH DELETE m
""", repo_name=repo_name)
await session.run("""
MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)-[:HAS_ATTRIBUTE]->(a:Attribute)
DETACH DELETE a
""", repo_name=repo_name)
# 2. Delete functions (they depend on files)
await session.run("""
MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(func:Function)
DETACH DELETE func
""", repo_name=repo_name)
# 3. Delete classes (they depend on files)
await session.run("""
MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)
DETACH DELETE c
""", repo_name=repo_name)
# 4. Delete files (they depend on repository)
await session.run("""
MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)
DETACH DELETE f
""", repo_name=repo_name)
# 5. Finally delete the repository
await session.run("""
MATCH (r:Repository {name: $repo_name})
DETACH DELETE r
""", repo_name=repo_name)
logger.info(f"Cleared data for repository: {repo_name}")
async def close(self):
"""Close Neo4j connection"""
if self.driver:
await self.driver.close()
def clone_repo(self, repo_url: str, target_dir: str) -> str:
"""Clone repository with shallow clone"""
logger.info(f"Cloning repository to: {target_dir}")
if os.path.exists(target_dir):
logger.info(f"Removing existing directory: {target_dir}")
try:
def handle_remove_readonly(func, path, exc):
try:
if os.path.exists(path):
os.chmod(path, 0o777)
func(path)
except PermissionError:
logger.warning(f"Could not remove {path} - file in use, skipping")
pass
shutil.rmtree(target_dir, onerror=handle_remove_readonly)
except Exception as e:
logger.warning(f"Could not fully remove {target_dir}: {e}. Proceeding anyway...")
logger.info(f"Running git clone from {repo_url}")
subprocess.run(['git', 'clone', '--depth', '1', repo_url, target_dir], check=True)
logger.info("Repository cloned successfully")
return target_dir
def get_python_files(self, repo_path: str) -> List[Path]:
"""Get Python files, focusing on main source directories"""
python_files = []
exclude_dirs = {
'tests', 'test', '__pycache__', '.git', 'venv', 'env',
'node_modules', 'build', 'dist', '.pytest_cache', 'docs',
'examples', 'example', 'demo', 'benchmark'
}
for root, dirs, files in os.walk(repo_path):
dirs[:] = [d for d in dirs if d not in exclude_dirs and not d.startswith('.')]
for file in files:
if file.endswith('.py') and not file.startswith('test_'):
file_path = Path(root) / file
if (file_path.stat().st_size < 500_000 and
file not in ['setup.py', 'conftest.py']):
python_files.append(file_path)
return python_files
async def analyze_repository(self, repo_url: str, temp_dir: str = None):
"""Analyze repository and create nodes/relationships in Neo4j"""
repo_name = repo_url.split('/')[-1].replace('.git', '')
logger.info(f"Analyzing repository: {repo_name}")
# Clear existing data for this repository before re-processing
await self.clear_repository_data(repo_name)
# Set default temp_dir to repos folder at script level
if temp_dir is None:
script_dir = Path(__file__).parent
temp_dir = str(script_dir / "repos" / repo_name)
# Clone and analyze
repo_path = Path(self.clone_repo(repo_url, temp_dir))
try:
logger.info("Getting Python files...")
python_files = self.get_python_files(str(repo_path))
logger.info(f"Found {len(python_files)} Python files to analyze")
# First pass: identify project modules
logger.info("Identifying project modules...")
project_modules = set()
for file_path in python_files:
relative_path = str(file_path.relative_to(repo_path))
module_parts = relative_path.replace('/', '.').replace('.py', '').split('.')
if len(module_parts) > 0 and not module_parts[0].startswith('.'):
project_modules.add(module_parts[0])
logger.info(f"Identified project modules: {sorted(project_modules)}")
# Second pass: analyze files and collect data
logger.info("Analyzing Python files...")
modules_data = []
for i, file_path in enumerate(python_files):
if i % 20 == 0:
logger.info(f"Analyzing file {i+1}/{len(python_files)}: {file_path.name}")
analysis = self.analyzer.analyze_python_file(file_path, repo_path, project_modules)
if analysis:
modules_data.append(analysis)
logger.info(f"Found {len(modules_data)} files with content")
# Create nodes and relationships in Neo4j
logger.info("Creating nodes and relationships in Neo4j...")
await self._create_graph(repo_name, modules_data)
# Print summary
total_classes = sum(len(mod['classes']) for mod in modules_data)
total_methods = sum(len(cls['methods']) for mod in modules_data for cls in mod['classes'])
total_functions = sum(len(mod['functions']) for mod in modules_data)
total_imports = sum(len(mod['imports']) for mod in modules_data)
print(f"\\n=== Direct Neo4j Repository Analysis for {repo_name} ===")
print(f"Files processed: {len(modules_data)}")
print(f"Classes created: {total_classes}")
print(f"Methods created: {total_methods}")
print(f"Functions created: {total_functions}")
print(f"Import relationships: {total_imports}")
logger.info(f"Successfully created Neo4j graph for {repo_name}")
finally:
if os.path.exists(temp_dir):
logger.info(f"Cleaning up temporary directory: {temp_dir}")
try:
def handle_remove_readonly(func, path, exc):
try:
if os.path.exists(path):
os.chmod(path, 0o777)
func(path)
except PermissionError:
logger.warning(f"Could not remove {path} - file in use, skipping")
pass
shutil.rmtree(temp_dir, onerror=handle_remove_readonly)
logger.info("Cleanup completed")
except Exception as e:
logger.warning(f"Cleanup failed: {e}. Directory may remain at {temp_dir}")
# Don't fail the whole process due to cleanup issues
async def _create_graph(self, repo_name: str, modules_data: List[Dict]):
"""Create all nodes and relationships in Neo4j"""
async with self.driver.session() as session:
# Create Repository node
await session.run(
"CREATE (r:Repository {name: $repo_name, created_at: datetime()})",
repo_name=repo_name
)
nodes_created = 0
relationships_created = 0
for i, mod in enumerate(modules_data):
# 1. Create File node
await session.run("""
CREATE (f:File {
name: $name,
path: $path,
module_name: $module_name,
line_count: $line_count,
created_at: datetime()
})
""",
name=mod['file_path'].split('/')[-1],
path=mod['file_path'],
module_name=mod['module_name'],
line_count=mod['line_count']
)
nodes_created += 1
# 2. Connect File to Repository
await session.run("""
MATCH (r:Repository {name: $repo_name})
MATCH (f:File {path: $file_path})
CREATE (r)-[:CONTAINS]->(f)
""", repo_name=repo_name, file_path=mod['file_path'])
relationships_created += 1
# 3. Create Class nodes and relationships
for cls in mod['classes']:
# Create Class node using MERGE to avoid duplicates
await session.run("""
MERGE (c:Class {full_name: $full_name})
ON CREATE SET c.name = $name, c.created_at = datetime()
""", name=cls['name'], full_name=cls['full_name'])
nodes_created += 1
# Connect File to Class
await session.run("""
MATCH (f:File {path: $file_path})
MATCH (c:Class {full_name: $class_full_name})
MERGE (f)-[:DEFINES]->(c)
""", file_path=mod['file_path'], class_full_name=cls['full_name'])
relationships_created += 1
# 4. Create Method nodes - use MERGE to avoid duplicates
for method in cls['methods']:
method_full_name = f"{cls['full_name']}.{method['name']}"
# Create method with unique ID to avoid conflicts
method_id = f"{cls['full_name']}::{method['name']}"
await session.run("""
MERGE (m:Method {method_id: $method_id})
ON CREATE SET m.name = $name,
m.full_name = $full_name,
m.args = $args,
m.params_list = $params_list,
m.params_detailed = $params_detailed,
m.return_type = $return_type,
m.created_at = datetime()
""",
name=method['name'],
full_name=method_full_name,
method_id=method_id,
args=method['args'],
params_list=[f"{p['name']}:{p['type']}" for p in method['params']], # Simple format
params_detailed=method.get('params_detailed', []), # Detailed format
return_type=method['return_type']
)
nodes_created += 1
# Connect Class to Method
await session.run("""
MATCH (c:Class {full_name: $class_full_name})
MATCH (m:Method {method_id: $method_id})
MERGE (c)-[:HAS_METHOD]->(m)
""",
class_full_name=cls['full_name'],
method_id=method_id
)
relationships_created += 1
# 5. Create Attribute nodes - use MERGE to avoid duplicates
for attr in cls['attributes']:
attr_full_name = f"{cls['full_name']}.{attr['name']}"
# Create attribute with unique ID to avoid conflicts
attr_id = f"{cls['full_name']}::{attr['name']}"
await session.run("""
MERGE (a:Attribute {attr_id: $attr_id})
ON CREATE SET a.name = $name,
a.full_name = $full_name,
a.type = $type,
a.created_at = datetime()
""",
name=attr['name'],
full_name=attr_full_name,
attr_id=attr_id,
type=attr['type']
)
nodes_created += 1
# Connect Class to Attribute
await session.run("""
MATCH (c:Class {full_name: $class_full_name})
MATCH (a:Attribute {attr_id: $attr_id})
MERGE (c)-[:HAS_ATTRIBUTE]->(a)
""",
class_full_name=cls['full_name'],
attr_id=attr_id
)
relationships_created += 1
# 6. Create Function nodes (top-level) - use MERGE to avoid duplicates
for func in mod['functions']:
func_id = f"{mod['file_path']}::{func['name']}"
await session.run("""
MERGE (f:Function {func_id: $func_id})
ON CREATE SET f.name = $name,
f.full_name = $full_name,
f.args = $args,
f.params_list = $params_list,
f.params_detailed = $params_detailed,
f.return_type = $return_type,
f.created_at = datetime()
""",
name=func['name'],
full_name=func['full_name'],
func_id=func_id,
args=func['args'],
params_list=func.get('params_list', []), # Simple format for backwards compatibility
params_detailed=func.get('params_detailed', []), # Detailed format
return_type=func['return_type']
)
nodes_created += 1
# Connect File to Function
await session.run("""
MATCH (file:File {path: $file_path})
MATCH (func:Function {func_id: $func_id})
MERGE (file)-[:DEFINES]->(func)
""", file_path=mod['file_path'], func_id=func_id)
relationships_created += 1
# 7. Create Import relationships
for import_name in mod['imports']:
# Try to find the target file
await session.run("""
MATCH (source:File {path: $source_path})
OPTIONAL MATCH (target:File)
WHERE target.module_name = $import_name OR target.module_name STARTS WITH $import_name
WITH source, target
WHERE target IS NOT NULL
MERGE (source)-[:IMPORTS]->(target)
""", source_path=mod['file_path'], import_name=import_name)
relationships_created += 1
if (i + 1) % 10 == 0:
logger.info(f"Processed {i + 1}/{len(modules_data)} files...")
logger.info(f"Created {nodes_created} nodes and {relationships_created} relationships")
async def search_graph(self, query_type: str, **kwargs):
    """Run one of a fixed set of read-only queries against the Neo4j graph.

    Supported query_type values and their kwargs:
      - "files_importing":  target (substring of the imported module name)
      - "classes_in_file":  file_path (exact File.path)
      - "methods_of_class": class_name (substring of Class.name/full_name)

    Returns a list of dicts for a recognized query_type; NOTE: an
    unrecognized query_type falls through and returns None.
    """
    async with self.driver.session() as session:
        if query_type == "files_importing":
            target = kwargs.get('target')
            result = await session.run("""
                MATCH (source:File)-[:IMPORTS]->(target:File)
                WHERE target.module_name CONTAINS $target
                RETURN source.path as file, target.module_name as imports
            """, target=target)
            return [{"file": record["file"], "imports": record["imports"]} async for record in result]
        elif query_type == "classes_in_file":
            file_path = kwargs.get('file_path')
            result = await session.run("""
                MATCH (f:File {path: $file_path})-[:DEFINES]->(c:Class)
                RETURN c.name as class_name, c.full_name as full_name
            """, file_path=file_path)
            return [{"class_name": record["class_name"], "full_name": record["full_name"]} async for record in result]
        elif query_type == "methods_of_class":
            class_name = kwargs.get('class_name')
            result = await session.run("""
                MATCH (c:Class)-[:HAS_METHOD]->(m:Method)
                WHERE c.name CONTAINS $class_name OR c.full_name CONTAINS $class_name
                RETURN m.name as method_name, m.args as args
            """, class_name=class_name)
            return [{"method_name": record["method_name"], "args": record["args"]} async for record in result]
async def main():
    """Example usage: extract a repository into Neo4j, then run sample queries."""
    load_dotenv()
    # Connection settings, defaulting to a local Neo4j instance
    uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
    user = os.environ.get('NEO4J_USER', 'neo4j')
    password = os.environ.get('NEO4J_PASSWORD', 'password')
    extractor = DirectNeo4jExtractor(uri, user, password)
    try:
        await extractor.initialize()
        # Analyze repository - direct Neo4j, no LLM processing!
        # repo_url = "https://github.com/pydantic/pydantic-ai.git"
        repo_url = "https://github.com/getzep/graphiti.git"
        await extractor.analyze_repository(repo_url)
        print("\\n=== Direct Neo4j Queries ===")
        # Sample query: which files import from 'models'?
        importers = await extractor.search_graph("files_importing", target="models")
        print(f"\\nFiles importing from 'models': {len(importers)}")
        for entry in importers[:3]:
            print(f"- {entry['file']} imports {entry['imports']}")
        # Sample query: which classes are defined in a specific file?
        classes = await extractor.search_graph("classes_in_file", file_path="pydantic_ai/models/openai.py")
        print(f"\\nClasses in openai.py: {len(classes)}")
        for entry in classes:
            print(f"- {entry['class_name']}")
        # Sample query: what methods does OpenAIModel expose?
        methods = await extractor.search_graph("methods_of_class", class_name="OpenAIModel")
        print(f"\\nMethods of OpenAIModel: {len(methods)}")
        for entry in methods[:5]:
            print(f"- {entry['method_name']}({', '.join(entry['args'])})")
    finally:
        # Always release the driver, even if analysis fails
        await extractor.close()
# Allow running this module directly as a script
if __name__ == "__main__":
    asyncio.run(main())
```
--------------------------------------------------------------------------------
/knowledge_graphs/knowledge_graph_validator.py:
--------------------------------------------------------------------------------
```python
"""
Knowledge Graph Validator
Validates AI-generated code against Neo4j knowledge graph containing
repository information. Checks imports, methods, attributes, and parameters.
"""
import asyncio
import logging
from typing import Dict, List, Optional, Set, Tuple, Any
from dataclasses import dataclass, field
from enum import Enum
from neo4j import AsyncGraphDatabase
from ai_script_analyzer import (
AnalysisResult, ImportInfo, MethodCall, AttributeAccess,
FunctionCall, ClassInstantiation
)
logger = logging.getLogger(__name__)
class ValidationStatus(Enum):
    """Outcome categories for a single validation check."""
    VALID = "VALID"          # confirmed against the knowledge graph
    INVALID = "INVALID"      # found, but used incorrectly (e.g. bad parameters)
    UNCERTAIN = "UNCERTAIN"  # cannot be checked (external module, unknown type)
    NOT_FOUND = "NOT_FOUND"  # expected in the knowledge graph but missing
@dataclass
class ValidationResult:
    """Result of validating a single element"""
    # Outcome category (VALID / INVALID / UNCERTAIN / NOT_FOUND)
    status: ValidationStatus
    confidence: float  # 0.0 to 1.0
    # Human-readable explanation of the outcome
    message: str
    # Extra structured context (e.g. matched files, external-module flags)
    details: Dict[str, Any] = field(default_factory=dict)
    # Suggested alternatives/corrections to surface to the user
    suggestions: List[str] = field(default_factory=list)
@dataclass
class ImportValidation:
    """Validation result for an import"""
    import_info: ImportInfo
    validation: ValidationResult
    # Names exposed by the matched module, when it is in the knowledge graph
    available_classes: List[str] = field(default_factory=list)
    available_functions: List[str] = field(default_factory=list)
@dataclass
class MethodValidation:
    """Validation result for a method call"""
    method_call: MethodCall
    validation: ValidationResult
    # Parameter names the knowledge graph records for this method
    expected_params: List[str] = field(default_factory=list)
    # Positional args plus keyword names actually supplied at the call site
    actual_params: List[str] = field(default_factory=list)
    # Signature-check result; None when parameters were not checked
    parameter_validation: Optional[ValidationResult] = None
@dataclass
class AttributeValidation:
    """Validation result for attribute access"""
    attribute_access: AttributeAccess
    validation: ValidationResult
    # Recorded type of the attribute, or "method" for decorator-style usage
    expected_type: Optional[str] = None
@dataclass
class FunctionValidation:
    """Validation result for function call"""
    function_call: FunctionCall
    validation: ValidationResult
    # Parameter names the knowledge graph records for this function
    expected_params: List[str] = field(default_factory=list)
    # Positional args plus keyword names actually supplied at the call site
    actual_params: List[str] = field(default_factory=list)
    # Signature-check result; None when parameters were not checked
    parameter_validation: Optional[ValidationResult] = None
@dataclass
class ClassValidation:
    """Validation result for class instantiation"""
    class_instantiation: ClassInstantiation
    validation: ValidationResult
    # Parameter names of the class's __init__, when recorded
    constructor_params: List[str] = field(default_factory=list)
    # Signature-check result; None when parameters were not checked
    parameter_validation: Optional[ValidationResult] = None
@dataclass
class ScriptValidationResult:
    """Complete validation results for a script"""
    script_path: str
    analysis_result: AnalysisResult
    # One entry per validated element, by category
    import_validations: List[ImportValidation] = field(default_factory=list)
    class_validations: List[ClassValidation] = field(default_factory=list)
    method_validations: List[MethodValidation] = field(default_factory=list)
    attribute_validations: List[AttributeValidation] = field(default_factory=list)
    function_validations: List[FunctionValidation] = field(default_factory=list)
    # Aggregate confidence across all validations (0.0 to 1.0)
    overall_confidence: float = 0.0
    # Concrete findings flagged as likely hallucinations
    hallucinations_detected: List[Dict[str, Any]] = field(default_factory=list)
class KnowledgeGraphValidator:
"""Validates code against Neo4j knowledge graph"""
def __init__(self, neo4j_uri: str, neo4j_user: str, neo4j_password: str):
    """Store connection settings; the Neo4j driver is created in initialize()."""
    self.neo4j_uri = neo4j_uri
    self.neo4j_user = neo4j_user
    self.neo4j_password = neo4j_password
    self.driver = None  # set by initialize(), released by close()
    # Cache for performance (avoids repeated Neo4j round-trips)
    self.module_cache: Dict[str, List[str]] = {}  # module name -> matching file paths
    self.class_cache: Dict[str, Dict[str, Any]] = {}  # class name -> class info
    self.method_cache: Dict[str, List[Dict[str, Any]]] = {}  # "Class.method" -> method info list
    self.repo_cache: Dict[str, str] = {} # module_name -> repo_name
    self.knowledge_graph_modules: Set[str] = set() # Track modules in knowledge graph
async def initialize(self):
    """Initialize Neo4j connection"""
    # NOTE(review): driver construction does not itself verify connectivity;
    # the first query will surface any connection error.
    self.driver = AsyncGraphDatabase.driver(
        self.neo4j_uri,
        auth=(self.neo4j_user, self.neo4j_password)
    )
    logger.info("Knowledge graph validator initialized")
async def close(self):
    """Close Neo4j connection"""
    # Safe to call even if initialize() was never run (driver is None)
    if self.driver:
        await self.driver.close()
async def validate_script(self, analysis_result: AnalysisResult) -> ScriptValidationResult:
    """Run every validation pass over a script analysis and aggregate the results."""
    report = ScriptValidationResult(
        script_path=analysis_result.file_path,
        analysis_result=analysis_result
    )
    # Imports run first: they record which modules live in the knowledge
    # graph, which the later passes consult to decide what to validate.
    report.import_validations = await self._validate_imports(analysis_result.imports)
    report.class_validations = await self._validate_class_instantiations(
        analysis_result.class_instantiations
    )
    report.method_validations = await self._validate_method_calls(
        analysis_result.method_calls
    )
    report.attribute_validations = await self._validate_attribute_accesses(
        analysis_result.attribute_accesses
    )
    report.function_validations = await self._validate_function_calls(
        analysis_result.function_calls
    )
    # Summarize: aggregate confidence plus a concrete hallucination list
    report.overall_confidence = self._calculate_overall_confidence(report)
    report.hallucinations_detected = self._detect_hallucinations(report)
    return report
async def _validate_imports(self, imports: List[ImportInfo]) -> List[ImportValidation]:
    """Validate every import in the script against the knowledge graph."""
    return [await self._validate_single_import(imp) for imp in imports]
async def _validate_single_import(self, import_info: ImportInfo) -> ImportValidation:
    """Validate one import.

    Modules present in the knowledge graph are reported VALID; anything
    else is treated as an external dependency (UNCERTAIN, not an error).
    """
    # "from X import Y" is looked up by module X; "import X" by the name X
    if import_info.is_from_import:
        search_module = import_info.module
    else:
        search_module = import_info.name
    # Consult the per-instance cache before hitting Neo4j
    available_files = self.module_cache.get(search_module)
    if available_files is None:
        available_files = await self._find_modules(search_module)
        self.module_cache[search_module] = available_files
    if not available_files:
        # Not indexed: assume an external library rather than a hallucination
        external = ValidationResult(
            status=ValidationStatus.UNCERTAIN,
            confidence=0.8,
            message=f"Module '{search_module}' is external (not in knowledge graph)",
            details={"could_be_external": True, "in_knowledge_graph": False}
        )
        return ImportValidation(import_info=import_info, validation=external)
    # Module is indexed: collect what it exposes and remember it for the
    # later validation passes.
    classes, functions = await self._get_module_contents(search_module)
    self.knowledge_graph_modules.add(search_module)
    if '.' in search_module:
        # Also track the base package for "from X.Y.Z import ..." patterns
        self.knowledge_graph_modules.add(search_module.split('.')[0])
    found = ValidationResult(
        status=ValidationStatus.VALID,
        confidence=0.9,
        message=f"Module '{search_module}' found in knowledge graph",
        details={"matched_files": available_files, "in_knowledge_graph": True}
    )
    return ImportValidation(
        import_info=import_info,
        validation=found,
        available_classes=classes,
        available_functions=functions
    )
async def _validate_class_instantiations(self, instantiations: List[ClassInstantiation]) -> List[ClassValidation]:
    """Validate every class instantiation in the script."""
    return [await self._validate_single_class_instantiation(inst) for inst in instantiations]
async def _validate_single_class_instantiation(self, instantiation: ClassInstantiation) -> ClassValidation:
    """Check that an instantiated class exists and that its constructor
    accepts the supplied arguments."""
    class_name = instantiation.full_class_name or instantiation.class_name
    # Only classes known to the knowledge graph are validated
    if not self._is_from_knowledge_graph(class_name):
        return ClassValidation(
            class_instantiation=instantiation,
            validation=ValidationResult(
                status=ValidationStatus.UNCERTAIN,
                confidence=0.8,
                message=f"Skipping validation: '{class_name}' is not from knowledge graph"
            )
        )
    class_info = await self._find_class(class_name)
    if not class_info:
        return ClassValidation(
            class_instantiation=instantiation,
            validation=ValidationResult(
                status=ValidationStatus.NOT_FOUND,
                confidence=0.2,
                message=f"Class '{class_name}' not found in knowledge graph"
            )
        )
    # The constructor signature lives on the __init__ method, when recorded
    init_method = await self._find_method(class_name, "__init__")
    if init_method:
        param_validation = self._validate_parameters(
            expected_params=init_method.get('params_list', []),
            provided_args=instantiation.args,
            provided_kwargs=instantiation.kwargs
        )
    else:
        param_validation = ValidationResult(
            status=ValidationStatus.UNCERTAIN,
            confidence=0.5,
            message="Constructor parameters not found"
        )
    # A failed signature check downgrades the class result to INVALID
    if param_validation.status == ValidationStatus.INVALID:
        validation = ValidationResult(
            status=ValidationStatus.INVALID,
            confidence=param_validation.confidence,
            message=f"Class '{class_name}' found but has invalid constructor parameters: {param_validation.message}",
            suggestions=param_validation.suggestions
        )
    else:
        validation = ValidationResult(
            status=ValidationStatus.VALID,
            confidence=0.8,
            message=f"Class '{class_name}' found in knowledge graph"
        )
    return ClassValidation(
        class_instantiation=instantiation,
        validation=validation,
        parameter_validation=param_validation
    )
async def _validate_method_calls(self, method_calls: List[MethodCall]) -> List[MethodValidation]:
    """Validate every method call in the script."""
    return [await self._validate_single_method_call(call) for call in method_calls]
async def _validate_single_method_call(self, method_call: MethodCall) -> MethodValidation:
    """Check one method call: the method must exist on the inferred class
    and the supplied arguments must fit its recorded signature."""
    class_type = method_call.object_type
    # Without a resolved object type there is nothing to look up
    if not class_type:
        return MethodValidation(
            method_call=method_call,
            validation=ValidationResult(
                status=ValidationStatus.UNCERTAIN,
                confidence=0.3,
                message=f"Cannot determine object type for '{method_call.object_name}'"
            )
        )
    # Only types that came from the knowledge graph are validated
    if not self._is_from_knowledge_graph(class_type):
        return MethodValidation(
            method_call=method_call,
            validation=ValidationResult(
                status=ValidationStatus.UNCERTAIN,
                confidence=0.8,
                message=f"Skipping validation: '{class_type}' is not from knowledge graph"
            )
        )
    method_info = await self._find_method(class_type, method_call.method_name)
    if not method_info:
        # Offer near-miss method names as suggestions
        similar_methods = await self._find_similar_methods(class_type, method_call.method_name)
        return MethodValidation(
            method_call=method_call,
            validation=ValidationResult(
                status=ValidationStatus.NOT_FOUND,
                confidence=0.1,
                message=f"Method '{method_call.method_name}' not found on class '{class_type}'",
                suggestions=similar_methods
            )
        )
    # Method exists: check the call-site arguments against its signature
    expected_params = method_info.get('params_list', [])
    param_validation = self._validate_parameters(
        expected_params=expected_params,
        provided_args=method_call.args,
        provided_kwargs=method_call.kwargs
    )
    if param_validation.status == ValidationStatus.INVALID:
        validation = ValidationResult(
            status=ValidationStatus.INVALID,
            confidence=param_validation.confidence,
            message=f"Method '{method_call.method_name}' found but has invalid parameters: {param_validation.message}",
            suggestions=param_validation.suggestions
        )
    else:
        validation = ValidationResult(
            status=ValidationStatus.VALID,
            confidence=0.9,
            message=f"Method '{method_call.method_name}' found on class '{class_type}'"
        )
    return MethodValidation(
        method_call=method_call,
        validation=validation,
        expected_params=expected_params,
        actual_params=method_call.args + list(method_call.kwargs.keys()),
        parameter_validation=param_validation
    )
async def _validate_attribute_accesses(self, attribute_accesses: List[AttributeAccess]) -> List[AttributeValidation]:
    """Validate every attribute access in the script."""
    return [await self._validate_single_attribute_access(access) for access in attribute_accesses]
async def _validate_single_attribute_access(self, attr_access: AttributeAccess) -> AttributeValidation:
    """Check that an accessed attribute exists on the inferred class.

    Falls back to a method lookup so decorator-style usages such as
    ``@agent.tool`` are not flagged as missing attributes.
    """
    class_type = attr_access.object_type
    # Without a resolved object type there is nothing to look up
    if not class_type:
        return AttributeValidation(
            attribute_access=attr_access,
            validation=ValidationResult(
                status=ValidationStatus.UNCERTAIN,
                confidence=0.3,
                message=f"Cannot determine object type for '{attr_access.object_name}'"
            )
        )
    # Only types that came from the knowledge graph are validated
    if not self._is_from_knowledge_graph(class_type):
        return AttributeValidation(
            attribute_access=attr_access,
            validation=ValidationResult(
                status=ValidationStatus.UNCERTAIN,
                confidence=0.8,
                message=f"Skipping validation: '{class_type}' is not from knowledge graph"
            )
        )
    attr_info = await self._find_attribute(class_type, attr_access.attribute_name)
    if attr_info:
        return AttributeValidation(
            attribute_access=attr_access,
            validation=ValidationResult(
                status=ValidationStatus.VALID,
                confidence=0.8,
                message=f"Attribute '{attr_access.attribute_name}' found on class '{class_type}'"
            ),
            expected_type=attr_info.get('type')
        )
    # Not an attribute -- maybe it is a method used as a decorator
    method_info = await self._find_method(class_type, attr_access.attribute_name)
    if method_info:
        return AttributeValidation(
            attribute_access=attr_access,
            validation=ValidationResult(
                status=ValidationStatus.VALID,
                confidence=0.8,
                message=f"'{attr_access.attribute_name}' found as method on class '{class_type}' (likely used as decorator)"
            ),
            expected_type="method"
        )
    return AttributeValidation(
        attribute_access=attr_access,
        validation=ValidationResult(
            status=ValidationStatus.NOT_FOUND,
            confidence=0.2,
            message=f"'{attr_access.attribute_name}' not found on class '{class_type}'"
        )
    )
async def _validate_function_calls(self, function_calls: List[FunctionCall]) -> List[FunctionValidation]:
    """Validate every function call in the script."""
    return [await self._validate_single_function_call(call) for call in function_calls]
async def _validate_single_function_call(self, func_call: FunctionCall) -> FunctionValidation:
    """Check that a called function exists and its arguments fit the signature."""
    func_name = func_call.full_name or func_call.function_name
    # Fully-qualified names outside the knowledge graph are skipped;
    # bare names still fall through to the lookup below.
    if func_call.full_name and not self._is_from_knowledge_graph(func_call.full_name):
        return FunctionValidation(
            function_call=func_call,
            validation=ValidationResult(
                status=ValidationStatus.UNCERTAIN,
                confidence=0.8,
                message=f"Skipping validation: '{func_name}' is not from knowledge graph"
            )
        )
    func_info = await self._find_function(func_name)
    if not func_info:
        return FunctionValidation(
            function_call=func_call,
            validation=ValidationResult(
                status=ValidationStatus.NOT_FOUND,
                confidence=0.2,
                message=f"Function '{func_name}' not found in knowledge graph"
            )
        )
    # Function exists: check the call-site arguments against its signature
    expected_params = func_info.get('params_list', [])
    param_validation = self._validate_parameters(
        expected_params=expected_params,
        provided_args=func_call.args,
        provided_kwargs=func_call.kwargs
    )
    if param_validation.status == ValidationStatus.INVALID:
        validation = ValidationResult(
            status=ValidationStatus.INVALID,
            confidence=param_validation.confidence,
            message=f"Function '{func_name}' found but has invalid parameters: {param_validation.message}",
            suggestions=param_validation.suggestions
        )
    else:
        validation = ValidationResult(
            status=ValidationStatus.VALID,
            confidence=0.8,
            message=f"Function '{func_name}' found in knowledge graph"
        )
    return FunctionValidation(
        function_call=func_call,
        validation=validation,
        expected_params=expected_params,
        actual_params=func_call.args + list(func_call.kwargs.keys()),
        parameter_validation=param_validation
    )
def _validate_parameters(self, expected_params: List[str], provided_args: List[str],
                         provided_kwargs: Dict[str, str]) -> ValidationResult:
    """Validate call-site arguments against an expected parameter list.

    ``expected_params`` entries come in two shapes (see the extractor):
      - detailed: "[kind] name:type=default" (kind prefix only when not
        plain positional; "=default" marks the parameter optional)
      - simple:   bare "name" or "name=default"

    Returns UNCERTAIN when no parameter information is available, INVALID
    with an explanatory message when arity or keyword names cannot match
    the signature, and VALID otherwise.
    """
    if not expected_params:
        return ValidationResult(
            status=ValidationStatus.UNCERTAIN,
            confidence=0.5,
            message="Parameter information not available"
        )
    # Buckets for the parsed signature
    required_positional = []
    optional_positional = []
    keyword_only_required = []
    keyword_only_optional = []
    has_varargs = False
    has_varkwargs = False
    for param in expected_params:
        param_clean = param.strip()
        # Optional "[kind]" prefix used by the detailed format
        kind = 'positional'
        if param_clean.startswith('['):
            end_bracket = param_clean.find(']')
            if end_bracket > 0:
                kind = param_clean[1:end_bracket]
                param_clean = param_clean[end_bracket+1:].strip()
        # *args / **kwargs relax the positional / keyword checks below
        if param_clean.startswith('**'):
            has_varkwargs = True
            continue
        if param_clean.startswith('*'):
            has_varargs = True
            continue
        if ':' in param_clean:
            # Detailed format: "name:type" with "=default" marking optional
            param_name = param_clean.split(':')[0]
            is_optional = '=' in param_clean
        else:
            # Simple format: bare "name" or "name=default".  Previously these
            # entries were silently dropped, so a simple-format signature had
            # zero allowed positionals and every call against it was flagged
            # "too many positional arguments".
            param_name = param_clean.split('=')[0].strip()
            is_optional = '=' in param_clean
            # Implicit receiver parameters are never supplied at the call
            # site (assumes the extractor may include them -- TODO confirm)
            if param_name in ('self', 'cls'):
                continue
        if kind == 'keyword_only':
            if is_optional:
                keyword_only_optional.append(param_name)
            else:
                keyword_only_required.append(param_name)
        else:  # positional (or positional-or-keyword)
            if is_optional:
                optional_positional.append(param_name)
            else:
                required_positional.append(param_name)
    # --- arity checks ---
    provided_positional_count = len(provided_args)
    provided_keyword_names = set(provided_kwargs.keys())
    max_allowed_positional = len(required_positional) + len(optional_positional)
    if not has_varargs and provided_positional_count > max_allowed_positional:
        return ValidationResult(
            status=ValidationStatus.INVALID,
            confidence=0.8,
            message=f"Too many positional arguments: provided {provided_positional_count}, max allowed {max_allowed_positional}"
        )
    # Required positional-or-keyword params not covered by positional args
    # may still be satisfied by keyword (previously counted as missing).
    missing_positional = [p for p in required_positional[provided_positional_count:]
                          if p not in provided_keyword_names]
    if missing_positional:
        return ValidationResult(
            status=ValidationStatus.INVALID,
            confidence=0.8,
            message=f"Missing required arguments: {missing_positional}"
        )
    # --- keyword checks ---
    all_valid_kwarg_names = set(required_positional + optional_positional + keyword_only_required + keyword_only_optional)
    invalid_kwargs = provided_keyword_names - all_valid_kwarg_names
    if invalid_kwargs and not has_varkwargs:
        return ValidationResult(
            status=ValidationStatus.INVALID,
            confidence=0.7,
            message=f"Invalid keyword arguments: {list(invalid_kwargs)}",
            suggestions=[f"Valid parameters: {list(all_valid_kwarg_names)}"]
        )
    # Keyword-only parameters without defaults must be supplied by name
    missing_required_kwargs = set(keyword_only_required) - provided_keyword_names
    if missing_required_kwargs:
        return ValidationResult(
            status=ValidationStatus.INVALID,
            confidence=0.8,
            message=f"Missing required keyword arguments: {list(missing_required_kwargs)}"
        )
    return ValidationResult(
        status=ValidationStatus.VALID,
        confidence=0.9,
        message="Parameters are valid"
    )
# Neo4j Query Methods
async def _find_modules(self, module_name: str) -> List[str]:
    """Find repository matching the module name, then return its files"""
    async with self.driver.session() as session:
        # First, try to find files with module names that match or start with the search term
        module_query = """
            MATCH (r:Repository)-[:CONTAINS]->(f:File)
            WHERE f.module_name = $module_name
               OR f.module_name STARTS WITH $module_name + '.'
               OR split(f.module_name, '.')[0] = $module_name
            RETURN DISTINCT r.name as repo_name, count(f) as file_count
            ORDER BY file_count DESC
            LIMIT 5
        """
        result = await session.run(module_query, module_name=module_name)
        repos_from_modules = []
        async for record in result:
            repos_from_modules.append(record['repo_name'])
        # Also try repository name matching as fallback
        # (covers dash/underscore differences between repo and package names)
        repo_query = """
            MATCH (r:Repository)
            WHERE toLower(r.name) = toLower($module_name)
               OR toLower(replace(r.name, '-', '_')) = toLower($module_name)
               OR toLower(replace(r.name, '_', '-')) = toLower($module_name)
            RETURN r.name as repo_name
            ORDER BY
                CASE
                    WHEN toLower(r.name) = toLower($module_name) THEN 1
                    WHEN toLower(replace(r.name, '-', '_')) = toLower($module_name) THEN 2
                    WHEN toLower(replace(r.name, '_', '-')) = toLower($module_name) THEN 3
                END
            LIMIT 5
        """
        result = await session.run(repo_query, module_name=module_name)
        repos_from_names = []
        async for record in result:
            repos_from_names.append(record['repo_name'])
        # Combine results, prioritizing module-based matches
        all_repos = repos_from_modules + [r for r in repos_from_names if r not in repos_from_modules]
        if not all_repos:
            return []
        # Get files from the best matching repository
        best_repo = all_repos[0]
        files_query = """
            MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)
            RETURN f.path, f.module_name
            LIMIT 50
        """
        result = await session.run(files_query, repo_name=best_repo)
        files = []
        async for record in result:
            # Un-aliased RETURN expressions are keyed by their literal text
            files.append(record['f.path'])
        return files
async def _get_module_contents(self, module_name: str) -> Tuple[List[str], List[str]]:
    """Get classes and functions available in a repository matching the module name"""
    async with self.driver.session() as session:
        # First, try to find repository by module names in files
        module_query = """
            MATCH (r:Repository)-[:CONTAINS]->(f:File)
            WHERE f.module_name = $module_name
               OR f.module_name STARTS WITH $module_name + '.'
               OR split(f.module_name, '.')[0] = $module_name
            RETURN DISTINCT r.name as repo_name, count(f) as file_count
            ORDER BY file_count DESC
            LIMIT 1
        """
        result = await session.run(module_query, module_name=module_name)
        record = await result.single()
        if record:
            repo_name = record['repo_name']
        else:
            # Fallback to repository name matching
            # (covers dash/underscore differences between repo and package names)
            repo_query = """
                MATCH (r:Repository)
                WHERE toLower(r.name) = toLower($module_name)
                   OR toLower(replace(r.name, '-', '_')) = toLower($module_name)
                   OR toLower(replace(r.name, '_', '-')) = toLower($module_name)
                RETURN r.name as repo_name
                ORDER BY
                    CASE
                        WHEN toLower(r.name) = toLower($module_name) THEN 1
                        WHEN toLower(replace(r.name, '-', '_')) = toLower($module_name) THEN 2
                        WHEN toLower(replace(r.name, '_', '-')) = toLower($module_name) THEN 3
                    END
                LIMIT 1
            """
            result = await session.run(repo_query, module_name=module_name)
            record = await result.single()
            if not record:
                # No matching repository at all: nothing to report
                return [], []
            repo_name = record['repo_name']
        # Get classes from this repository
        class_query = """
            MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)
            RETURN DISTINCT c.name as class_name
        """
        result = await session.run(class_query, repo_name=repo_name)
        classes = []
        async for record in result:
            classes.append(record['class_name'])
        # Get functions from this repository
        func_query = """
            MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(func:Function)
            RETURN DISTINCT func.name as function_name
        """
        result = await session.run(func_query, repo_name=repo_name)
        functions = []
        async for record in result:
            functions.append(record['function_name'])
        return classes, functions
async def _find_repository_for_module(self, module_name: str) -> Optional[str]:
    """Find the repository name that matches a module name"""
    # Cache holds both positive results and None (negative caching)
    if module_name in self.repo_cache:
        return self.repo_cache[module_name]
    async with self.driver.session() as session:
        # First, try to find repository by module names in files
        module_query = """
            MATCH (r:Repository)-[:CONTAINS]->(f:File)
            WHERE f.module_name = $module_name
               OR f.module_name STARTS WITH $module_name + '.'
               OR split(f.module_name, '.')[0] = $module_name
            RETURN DISTINCT r.name as repo_name, count(f) as file_count
            ORDER BY file_count DESC
            LIMIT 1
        """
        result = await session.run(module_query, module_name=module_name)
        record = await result.single()
        if record:
            repo_name = record['repo_name']
        else:
            # Fallback to repository name matching; this variant is looser
            # than the others (CONTAINS matches), so it can pick a related
            # repo when no exact name match exists
            query = """
                MATCH (r:Repository)
                WHERE toLower(r.name) = toLower($module_name)
                   OR toLower(replace(r.name, '-', '_')) = toLower($module_name)
                   OR toLower(replace(r.name, '_', '-')) = toLower($module_name)
                   OR toLower(r.name) CONTAINS toLower($module_name)
                   OR toLower($module_name) CONTAINS toLower(replace(r.name, '-', '_'))
                RETURN r.name as repo_name
                ORDER BY
                    CASE
                        WHEN toLower(r.name) = toLower($module_name) THEN 1
                        WHEN toLower(replace(r.name, '-', '_')) = toLower($module_name) THEN 2
                        ELSE 3
                    END
                LIMIT 1
            """
            result = await session.run(query, module_name=module_name)
            record = await result.single()
            repo_name = record['repo_name'] if record else None
        self.repo_cache[module_name] = repo_name
        return repo_name
async def _find_class(self, class_name: str) -> Optional[Dict[str, Any]]:
    """Find class information in knowledge graph"""
    async with self.driver.session() as session:
        # First try exact match on name or fully-qualified name
        query = """
            MATCH (c:Class)
            WHERE c.name = $class_name OR c.full_name = $class_name
            RETURN c.name as name, c.full_name as full_name
            LIMIT 1
        """
        result = await session.run(query, class_name=class_name)
        record = await result.single()
        if record:
            return {
                'name': record['name'],
                'full_name': record['full_name']
            }
        # If no exact match and class_name has dots, try repository-based search:
        # resolve the module prefix to a repository, then look up the bare
        # class name within it
        if '.' in class_name:
            parts = class_name.split('.')
            module_part = '.'.join(parts[:-1])  # e.g., "pydantic_ai"
            class_part = parts[-1]  # e.g., "Agent"
            # Find repository for the module
            repo_name = await self._find_repository_for_module(module_part)
            if repo_name:
                # Search for class within this repository
                repo_query = """
                    MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)
                    WHERE c.name = $class_name
                    RETURN c.name as name, c.full_name as full_name
                    LIMIT 1
                """
                result = await session.run(repo_query, repo_name=repo_name, class_name=class_part)
                record = await result.single()
                if record:
                    return {
                        'name': record['name'],
                        'full_name': record['full_name']
                    }
        # Not found by either strategy
        return None
async def _find_method(self, class_name: str, method_name: str) -> Optional[Dict[str, Any]]:
    """Find method information for a class"""
    cache_key = f"{class_name}.{method_name}"
    # Cache stores a list; an empty list means "known absent"
    if cache_key in self.method_cache:
        methods = self.method_cache[cache_key]
        return methods[0] if methods else None
    async with self.driver.session() as session:
        # First try exact match on class name or fully-qualified name
        query = """
            MATCH (c:Class)-[:HAS_METHOD]->(m:Method)
            WHERE (c.name = $class_name OR c.full_name = $class_name)
              AND m.name = $method_name
            RETURN m.name as name, m.params_list as params_list, m.params_detailed as params_detailed,
                   m.return_type as return_type, m.args as args
            LIMIT 1
        """
        result = await session.run(query, class_name=class_name, method_name=method_name)
        record = await result.single()
        if record:
            # Use detailed params if available, fall back to simple params
            params_to_use = record['params_detailed'] or record['params_list'] or []
            method_info = {
                'name': record['name'],
                'params_list': params_to_use,
                'return_type': record['return_type'],
                'args': record['args'] or []
            }
            self.method_cache[cache_key] = [method_info]
            return method_info
        # If no exact match and class_name has dots, try repository-based search:
        # resolve the module prefix to a repository, then look up the bare
        # class name within it
        if '.' in class_name:
            parts = class_name.split('.')
            module_part = '.'.join(parts[:-1])  # e.g., "pydantic_ai"
            class_part = parts[-1]  # e.g., "Agent"
            # Find repository for the module
            repo_name = await self._find_repository_for_module(module_part)
            if repo_name:
                # Search for method within this repository's classes
                repo_query = """
                    MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)-[:HAS_METHOD]->(m:Method)
                    WHERE c.name = $class_name AND m.name = $method_name
                    RETURN m.name as name, m.params_list as params_list, m.params_detailed as params_detailed,
                           m.return_type as return_type, m.args as args
                    LIMIT 1
                """
                result = await session.run(repo_query, repo_name=repo_name, class_name=class_part, method_name=method_name)
                record = await result.single()
                if record:
                    # Use detailed params if available, fall back to simple params
                    params_to_use = record['params_detailed'] or record['params_list'] or []
                    method_info = {
                        'name': record['name'],
                        'params_list': params_to_use,
                        'return_type': record['return_type'],
                        'args': record['args'] or []
                    }
                    self.method_cache[cache_key] = [method_info]
                    return method_info
        # Remember the miss so repeated lookups skip Neo4j
        self.method_cache[cache_key] = []
        return None
async def _find_attribute(self, class_name: str, attr_name: str) -> Optional[Dict[str, Any]]:
    """Find attribute information for a class"""
    async with self.driver.session() as session:
        # First try exact match on class name or fully-qualified name
        query = """
            MATCH (c:Class)-[:HAS_ATTRIBUTE]->(a:Attribute)
            WHERE (c.name = $class_name OR c.full_name = $class_name)
              AND a.name = $attr_name
            RETURN a.name as name, a.type as type
            LIMIT 1
        """
        result = await session.run(query, class_name=class_name, attr_name=attr_name)
        record = await result.single()
        if record:
            return {
                'name': record['name'],
                'type': record['type']
            }
        # If no exact match and class_name has dots, try repository-based search:
        # resolve the module prefix to a repository, then look up the bare
        # class name within it
        if '.' in class_name:
            parts = class_name.split('.')
            module_part = '.'.join(parts[:-1])  # e.g., "pydantic_ai"
            class_part = parts[-1]  # e.g., "Agent"
            # Find repository for the module
            repo_name = await self._find_repository_for_module(module_part)
            if repo_name:
                # Search for attribute within this repository's classes
                repo_query = """
                    MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)-[:HAS_ATTRIBUTE]->(a:Attribute)
                    WHERE c.name = $class_name AND a.name = $attr_name
                    RETURN a.name as name, a.type as type
                    LIMIT 1
                """
                result = await session.run(repo_query, repo_name=repo_name, class_name=class_part, attr_name=attr_name)
                record = await result.single()
                if record:
                    return {
                        'name': record['name'],
                        'type': record['type']
                    }
        # Not found by either strategy
        return None
async def _find_function(self, func_name: str) -> Optional[Dict[str, Any]]:
    """Look up a free function in the knowledge graph by (possibly dotted) name.

    A direct match on name or full name is attempted first; for dotted
    names the module prefix is resolved to a repository and the bare
    function name is searched among that repository's files.

    Returns:
        Dict with 'name', 'params_list', 'return_type' and 'args', or
        None when nothing matches.
    """
    def to_info(rec) -> Dict[str, Any]:
        # Prefer the detailed parameter representation when the node has it.
        params = rec['params_detailed'] or rec['params_list'] or []
        return {
            'name': rec['name'],
            'params_list': params,
            'return_type': rec['return_type'],
            'args': rec['args'] or []
        }

    async with self.driver.session() as session:
        direct_query = """
        MATCH (f:Function)
        WHERE f.name = $func_name OR f.full_name = $func_name
        RETURN f.name as name, f.params_list as params_list, f.params_detailed as params_detailed,
               f.return_type as return_type, f.args as args
        LIMIT 1
        """
        result = await session.run(direct_query, func_name=func_name)
        rec = await result.single()
        if rec:
            return to_info(rec)

        # Dotted name: resolve "module.function" through the repository graph.
        if '.' not in func_name:
            return None
        *module_parts, bare_name = func_name.split('.')
        repo_name = await self._find_repository_for_module('.'.join(module_parts))
        if not repo_name:
            return None

        scoped_query = """
        MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(func:Function)
        WHERE func.name = $func_name
        RETURN func.name as name, func.params_list as params_list, func.params_detailed as params_detailed,
               func.return_type as return_type, func.args as args
        LIMIT 1
        """
        result = await session.run(scoped_query, repo_name=repo_name, func_name=bare_name)
        rec = await result.single()
        if rec:
            return to_info(rec)

        return None
async def _find_pydantic_ai_result_method(self, method_name: str, repo_name: str = "pydantic_ai") -> Optional[Dict[str, Any]]:
    """Find a method that plausibly belongs to a result-like object.

    Searches the given repository for a method with the requested name
    defined on a class whose name suggests a result object (contains
    'Result', 'Stream' or 'Run').

    Args:
        method_name: Name of the method to look up.
        repo_name: Repository node to search. Defaults to "pydantic_ai",
            preserving the original hard-coded behavior for existing callers.

    Returns:
        Dict with 'name', 'params_list', 'return_type', 'args' and
        'source_class' (the class the method was found on), or None.
    """
    async with self.driver.session() as session:
        query = """
        MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)-[:HAS_METHOD]->(m:Method)
        WHERE m.name = $method_name
        AND (c.name CONTAINS 'Result' OR c.name CONTAINS 'Stream' OR c.name CONTAINS 'Run')
        RETURN m.name as name, m.params_list as params_list, m.params_detailed as params_detailed,
               m.return_type as return_type, m.args as args, c.name as class_name
        LIMIT 1
        """
        result = await session.run(query, repo_name=repo_name, method_name=method_name)
        record = await result.single()
        if record:
            # Prefer detailed parameter info when the graph provides it.
            params_to_use = record['params_detailed'] or record['params_list'] or []
            return {
                'name': record['name'],
                'params_list': params_to_use,
                'return_type': record['return_type'],
                'args': record['args'] or [],
                'source_class': record['class_name']
            }
        return None
async def _find_similar_modules(self, module_name: str) -> List[str]:
    """Suggest repository names resembling the given module name.

    Matches case-insensitively on the first three characters of the module
    name, also trying hyphen/underscore-swapped forms of repository names
    so "my-repo" can match "my_repo" and vice versa.

    Returns:
        Up to five matching repository names (possibly empty).
    """
    async with self.driver.session() as session:
        # BUG FIX: the original query returned `r.name` without an alias,
        # so the record key was 'r.name' and record['name'] raised KeyError.
        # Aliasing with `as name` matches the sibling queries in this class.
        query = """
        MATCH (r:Repository)
        WHERE toLower(r.name) CONTAINS toLower($partial_name)
        OR toLower(replace(r.name, '-', '_')) CONTAINS toLower($partial_name)
        OR toLower(replace(r.name, '_', '-')) CONTAINS toLower($partial_name)
        RETURN r.name as name
        LIMIT 5
        """
        result = await session.run(query, partial_name=module_name[:3])
        suggestions = []
        async for record in result:
            suggestions.append(record['name'])
        return suggestions
async def _find_similar_methods(self, class_name: str, method_name: str) -> List[str]:
    """Suggest method names similar to `method_name` on the given class.

    Matches on the first three characters of the method name. When the
    direct lookup yields nothing and the class name is dotted, falls back
    to a repository-scoped search resolved from the module prefix.

    Returns:
        Up to five method names per query (possibly empty).
    """
    prefix = method_name[:3]
    suggestions: List[str] = []
    async with self.driver.session() as session:
        direct_query = """
        MATCH (c:Class)-[:HAS_METHOD]->(m:Method)
        WHERE (c.name = $class_name OR c.full_name = $class_name)
        AND m.name CONTAINS $partial_name
        RETURN m.name as name
        LIMIT 5
        """
        result = await session.run(direct_query, class_name=class_name, partial_name=prefix)
        async for record in result:
            suggestions.append(record['name'])

        # Direct hit (or nothing to fall back on): done.
        if suggestions or '.' not in class_name:
            return suggestions

        *module_parts, bare_class = class_name.split('.')
        repo_name = await self._find_repository_for_module('.'.join(module_parts))
        if repo_name:
            scoped_query = """
            MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)-[:HAS_METHOD]->(m:Method)
            WHERE c.name = $class_name AND m.name CONTAINS $partial_name
            RETURN m.name as name
            LIMIT 5
            """
            result = await session.run(scoped_query, repo_name=repo_name, class_name=bare_class, partial_name=prefix)
            async for record in result:
                suggestions.append(record['name'])
    return suggestions
def _calculate_overall_confidence(self, result: ScriptValidationResult) -> float:
    """Average the confidence of validations tied to knowledge-graph items.

    Validations whose subject does not originate from a module in the
    knowledge graph are ignored, so external-library usage never drags the
    score down. Returns 1.0 when nothing in the script touches the graph.
    """
    scores: List[float] = []

    # Imports: only those the validator flagged as present in the graph.
    scores.extend(
        v.validation.confidence
        for v in result.import_validations
        if v.validation.details.get('in_knowledge_graph', False)
    )

    # Class instantiations: prefer the fully-qualified name when available.
    for v in result.class_validations:
        name = v.class_instantiation.full_class_name or v.class_instantiation.class_name
        if self._is_from_knowledge_graph(name):
            scores.append(v.validation.confidence)

    # Method calls, keyed by the receiver's inferred type.
    for v in result.method_validations:
        if v.method_call.object_type and self._is_from_knowledge_graph(v.method_call.object_type):
            scores.append(v.validation.confidence)

    # Attribute accesses, same rule as method calls.
    for v in result.attribute_validations:
        if v.attribute_access.object_type and self._is_from_knowledge_graph(v.attribute_access.object_type):
            scores.append(v.validation.confidence)

    # Free function calls by fully-qualified name.
    for v in result.function_validations:
        if v.function_call.full_name and self._is_from_knowledge_graph(v.function_call.full_name):
            scores.append(v.validation.confidence)

    if not scores:
        return 1.0  # No knowledge-graph items to validate = perfect confidence.
    return sum(scores) / len(scores)
def _is_from_knowledge_graph(self, class_type: str) -> bool:
"""Check if a class type comes from a module in the knowledge graph"""
if not class_type:
return False
# For dotted names like "pydantic_ai.Agent" or "pydantic_ai.StreamedRunResult", check the base module
if '.' in class_type:
base_module = class_type.split('.')[0]
# Exact match only - "pydantic" should not match "pydantic_ai"
return base_module in self.knowledge_graph_modules
# For simple names, check if any knowledge graph module matches exactly
# Don't use substring matching to avoid "pydantic" matching "pydantic_ai"
return class_type in self.knowledge_graph_modules
def _detect_hallucinations(self, result: ScriptValidationResult) -> List[Dict[str, Any]]:
    """Collect and categorize hallucination reports from failed validations.

    Only items whose receiver type comes from a knowledge-graph module are
    reported. Method and attribute findings are deduplicated by the
    (line, member name, object type) triple, so an access already reported
    as a missing method is not reported again as a missing attribute.
    """
    findings: List[Dict[str, Any]] = []
    seen = set()  # (line_number, member_name, object_type) already reported

    # Missing methods on knowledge-graph classes.
    for v in result.method_validations:
        call = v.method_call
        if (v.validation.status != ValidationStatus.NOT_FOUND
                or not call.object_type
                or not self._is_from_knowledge_graph(call.object_type)):
            continue
        key = (call.line_number, call.method_name, call.object_type)
        if key in seen:
            continue
        seen.add(key)
        findings.append({
            'type': 'METHOD_NOT_FOUND',
            'location': f"line {call.line_number}",
            'description': f"Method '{call.method_name}' not found on class '{call.object_type}'",
            'suggestion': v.validation.suggestions[0] if v.validation.suggestions else None
        })

    # Missing attributes on knowledge-graph classes; anything already
    # reported as a method at the same (line, name, type) is skipped.
    for v in result.attribute_validations:
        access = v.attribute_access
        if (v.validation.status != ValidationStatus.NOT_FOUND
                or not access.object_type
                or not self._is_from_knowledge_graph(access.object_type)):
            continue
        key = (access.line_number, access.attribute_name, access.object_type)
        if key in seen:
            continue
        seen.add(key)
        findings.append({
            'type': 'ATTRIBUTE_NOT_FOUND',
            'location': f"line {access.line_number}",
            'description': f"Attribute '{access.attribute_name}' not found on class '{access.object_type}'"
        })

    # Invalid parameter lists on methods of knowledge-graph classes.
    for v in result.method_validations:
        call = v.method_call
        if (v.parameter_validation
                and v.parameter_validation.status == ValidationStatus.INVALID
                and call.object_type
                and self._is_from_knowledge_graph(call.object_type)):
            findings.append({
                'type': 'INVALID_PARAMETERS',
                'location': f"line {call.line_number}",
                'description': f"Invalid parameters for method '{call.method_name}': {v.parameter_validation.message}"
            })

    return findings
```