feat: add MCP transport validation to ToolExecutor
parent
29ce72528b
commit
df442272e9
|
|
@ -1,12 +1,13 @@
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
import ast
|
import ast
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
|
|
||||||
import uuid_utils.compat as uuid
|
import uuid_utils.compat as uuid
|
||||||
|
from django.utils.translation import gettext_lazy as _
|
||||||
|
|
||||||
from maxkb.const import BASE_DIR, CONFIG
|
from maxkb.const import BASE_DIR, CONFIG
|
||||||
from maxkb.const import PROJECT_DIR
|
from maxkb.const import PROJECT_DIR
|
||||||
|
|
@ -210,6 +211,12 @@ exec({dedent(code)!a})
|
||||||
if matched:
|
if matched:
|
||||||
raise Exception(f"keyword '{matched}' is banned in the tool.")
|
raise Exception(f"keyword '{matched}' is banned in the tool.")
|
||||||
|
|
||||||
|
def validate_mcp_transport(self, code_str):
|
||||||
|
servers = json.loads(code_str)
|
||||||
|
for server, config in servers.items():
|
||||||
|
if config.get('transport') not in ['sse', 'streamable_http']:
|
||||||
|
raise Exception(_('Only support transport=sse or transport=streamable_http'))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _exec(_code):
|
def _exec(_code):
|
||||||
return subprocess.run([python_directory, '-c', _code], text=True, capture_output=True)
|
return subprocess.run([python_directory, '-c', _code], text=True, capture_output=True)
|
||||||
|
|
|
||||||
|
|
@ -356,6 +356,7 @@ class ToolSerializer(serializers.Serializer):
|
||||||
ToolCreateRequest(data=instance).is_valid(raise_exception=True)
|
ToolCreateRequest(data=instance).is_valid(raise_exception=True)
|
||||||
# 校验代码是否包括禁止的关键字
|
# 校验代码是否包括禁止的关键字
|
||||||
ToolExecutor().validate_banned_keywords(instance.get('code', ''))
|
ToolExecutor().validate_banned_keywords(instance.get('code', ''))
|
||||||
|
ToolExecutor().validate_mcp_transport(instance.get('code', ''))
|
||||||
|
|
||||||
tool_id = uuid.uuid7()
|
tool_id = uuid.uuid7()
|
||||||
Tool(
|
Tool(
|
||||||
|
|
@ -391,6 +392,8 @@ class ToolSerializer(serializers.Serializer):
|
||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
# 校验代码是否包括禁止的关键字
|
# 校验代码是否包括禁止的关键字
|
||||||
ToolExecutor().validate_banned_keywords(self.data.get('code', ''))
|
ToolExecutor().validate_banned_keywords(self.data.get('code', ''))
|
||||||
|
ToolExecutor().validate_mcp_transport(self.data.get('code', ''))
|
||||||
|
|
||||||
# 校验mcp json
|
# 校验mcp json
|
||||||
validate_mcp_config(json.loads(self.data.get('code')))
|
validate_mcp_config(json.loads(self.data.get('code')))
|
||||||
return True
|
return True
|
||||||
|
|
@ -484,7 +487,7 @@ class ToolSerializer(serializers.Serializer):
|
||||||
ToolEditRequest(data=instance).is_valid(raise_exception=True)
|
ToolEditRequest(data=instance).is_valid(raise_exception=True)
|
||||||
# 校验代码是否包括禁止的关键字
|
# 校验代码是否包括禁止的关键字
|
||||||
ToolExecutor().validate_banned_keywords(instance.get('code', ''))
|
ToolExecutor().validate_banned_keywords(instance.get('code', ''))
|
||||||
|
ToolExecutor().validate_mcp_transport(instance.get('code', ''))
|
||||||
|
|
||||||
if not QuerySet(Tool).filter(id=self.data.get('id')).exists():
|
if not QuerySet(Tool).filter(id=self.data.get('id')).exists():
|
||||||
raise serializers.ValidationError(_('Tool not found'))
|
raise serializers.ValidationError(_('Tool not found'))
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue