feat: add MCP transport validation to ToolExecutor

v3.2
CaptainB 2025-09-28 15:33:25 +08:00
parent 29ce72528b
commit df442272e9
2 changed files with 12 additions and 2 deletions

View File

@ -1,12 +1,13 @@
# coding=utf-8 # coding=utf-8
import ast import ast
import os
import json import json
import os
import subprocess import subprocess
import sys import sys
from textwrap import dedent from textwrap import dedent
import uuid_utils.compat as uuid import uuid_utils.compat as uuid
from django.utils.translation import gettext_lazy as _
from maxkb.const import BASE_DIR, CONFIG from maxkb.const import BASE_DIR, CONFIG
from maxkb.const import PROJECT_DIR from maxkb.const import PROJECT_DIR
@ -210,6 +211,12 @@ exec({dedent(code)!a})
if matched: if matched:
raise Exception(f"keyword '{matched}' is banned in the tool.") raise Exception(f"keyword '{matched}' is banned in the tool.")
def validate_mcp_transport(self, code_str):
servers = json.loads(code_str)
for server, config in servers.items():
if config.get('transport') not in ['sse', 'streamable_http']:
raise Exception(_('Only support transport=sse or transport=streamable_http'))
@staticmethod @staticmethod
def _exec(_code): def _exec(_code):
return subprocess.run([python_directory, '-c', _code], text=True, capture_output=True) return subprocess.run([python_directory, '-c', _code], text=True, capture_output=True)

View File

@ -356,6 +356,7 @@ class ToolSerializer(serializers.Serializer):
ToolCreateRequest(data=instance).is_valid(raise_exception=True) ToolCreateRequest(data=instance).is_valid(raise_exception=True)
# 校验代码是否包括禁止的关键字 # 校验代码是否包括禁止的关键字
ToolExecutor().validate_banned_keywords(instance.get('code', '')) ToolExecutor().validate_banned_keywords(instance.get('code', ''))
ToolExecutor().validate_mcp_transport(instance.get('code', ''))
tool_id = uuid.uuid7() tool_id = uuid.uuid7()
Tool( Tool(
@ -391,6 +392,8 @@ class ToolSerializer(serializers.Serializer):
self.is_valid(raise_exception=True) self.is_valid(raise_exception=True)
# 校验代码是否包括禁止的关键字 # 校验代码是否包括禁止的关键字
ToolExecutor().validate_banned_keywords(self.data.get('code', '')) ToolExecutor().validate_banned_keywords(self.data.get('code', ''))
ToolExecutor().validate_mcp_transport(self.data.get('code', ''))
# 校验mcp json # 校验mcp json
validate_mcp_config(json.loads(self.data.get('code'))) validate_mcp_config(json.loads(self.data.get('code')))
return True return True
@ -484,7 +487,7 @@ class ToolSerializer(serializers.Serializer):
ToolEditRequest(data=instance).is_valid(raise_exception=True) ToolEditRequest(data=instance).is_valid(raise_exception=True)
# 校验代码是否包括禁止的关键字 # 校验代码是否包括禁止的关键字
ToolExecutor().validate_banned_keywords(instance.get('code', '')) ToolExecutor().validate_banned_keywords(instance.get('code', ''))
ToolExecutor().validate_mcp_transport(instance.get('code', ''))
if not QuerySet(Tool).filter(id=self.data.get('id')).exists(): if not QuerySet(Tool).filter(id=self.data.get('id')).exists():
raise serializers.ValidationError(_('Tool not found')) raise serializers.ValidationError(_('Tool not found'))