2024-07-01 01:45:59 +00:00
|
|
|
|
# coding=utf-8
|
|
|
|
|
|
"""
|
|
|
|
|
|
@project: maxkb
|
|
|
|
|
|
@Author:虎
|
|
|
|
|
|
@file: workflow_manage.py
|
|
|
|
|
|
@date:2024/1/9 17:40
|
|
|
|
|
|
@desc:
|
|
|
|
|
|
"""
|
2024-08-02 06:21:29 +00:00
|
|
|
|
import json
|
2024-07-01 01:45:59 +00:00
|
|
|
|
from functools import reduce
|
|
|
|
|
|
from typing import List, Dict
|
|
|
|
|
|
|
2024-08-23 09:46:05 +00:00
|
|
|
|
from django.db.models import QuerySet
|
2024-08-02 06:21:29 +00:00
|
|
|
|
from langchain_core.messages import AIMessage
|
2024-07-01 01:45:59 +00:00
|
|
|
|
from langchain_core.prompts import PromptTemplate
|
|
|
|
|
|
|
|
|
|
|
|
from application.flow import tools
|
|
|
|
|
|
from application.flow.i_step_node import INode, WorkFlowPostHandler, NodeResult
|
|
|
|
|
|
from application.flow.step_node import get_node
|
|
|
|
|
|
from common.exception.app_exception import AppApiException
|
2024-08-23 09:46:05 +00:00
|
|
|
|
from setting.models import Model
|
|
|
|
|
|
from setting.models_provider import get_model_credential
|
2024-07-01 01:45:59 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Edge:
    """
    A directed connection between two workflow nodes.

    ``sourceNodeId`` / ``targetNodeId`` name the two endpoints. Every extra
    keyword argument (for example ``sourceAnchorId``) becomes an attribute
    of the same name.
    """

    def __init__(self, _id: str, _type: str, sourceNodeId: str, targetNodeId: str, **keywords):
        self.id, self.type = _id, _type
        self.sourceNodeId, self.targetNodeId = sourceNodeId, targetNodeId
        # Remaining fields are copied verbatim; a key such as "id" or "type"
        # overrides the positional value assigned above.
        for name, value in keywords.items():
            setattr(self, name, value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Node:
    """
    A workflow node with its canvas position.

    ``properties`` is the node's configuration dict; keys such as
    ``node_data`` and ``stepName`` are read from it elsewhere in this file.
    Every extra keyword argument becomes an attribute of the same name.
    """

    def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwargs):
        self.id, self.type = _id, _type
        self.x, self.y = x, y
        self.properties = properties
        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)
|
|
|
|
|
|
|
|
|
|
|
|
|
2024-08-15 09:17:25 +00:00
|
|
|
|
# Node types that are allowed to have no outgoing edge: the flow checks
# only reject a node without a successor when its type is NOT listed here.
end_nodes = [
    'ai-chat-node',
    'reply-node',
    'function-node',
    'function-lib-node',
]
|
2024-07-01 01:45:59 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Flow:
    """
    A workflow graph (nodes and edges) together with the validation rules the
    graph has to satisfy.
    """

    def __init__(self, nodes: List[Node], edges: List[Edge]):
        self.nodes = nodes
        self.edges = edges

    @staticmethod
    def new_instance(flow_obj: Dict):
        """
        Build a Flow from its dict form.

        @param flow_obj: dict holding a ``nodes`` list and an ``edges`` list; each entry
                         is expanded into Node / Edge keyword arguments
        @return: the new Flow
        """
        nodes = [Node(node.get('id'), node.get('type'), **node) for node in flow_obj.get('nodes')]
        edges = [Edge(edge.get('id'), edge.get('type'), **edge) for edge in flow_obj.get('edges')]
        return Flow(nodes, edges)

    def get_start_node(self):
        """Return the node whose id is 'start-node' (IndexError when there is none)."""
        start_node_list = [node for node in self.nodes if node.id == 'start-node']
        return start_node_list[0]

    def get_search_node(self):
        """Return every node of type 'search-dataset-node'."""
        return [node for node in self.nodes if node.type == 'search-dataset-node']

    def is_valid(self):
        """
        Validate the workflow data.

        Runs, in order: model parameter checks, start-node checks, base-node checks,
        and a walk over every node reachable from the start node. The checks defined
        in this class raise AppApiException on failure.
        """
        self.is_valid_model_params()
        self.is_valid_start_node()
        self.is_valid_base_node()
        self.is_valid_work_flow()

    @staticmethod
    def is_valid_node_params(node: Node):
        """
        Instantiate the node implementation registered for ``node.type`` with no
        workflow params and no workflow manager.

        Presumably the node constructor validates the node's parameters and raises
        when they are invalid — confirm in application.flow.i_step_node.
        """
        get_node(node.type)(node, None, None)

    def is_valid_node(self, node: Node):
        """
        Validate one node's parameters and outgoing connections.

        A condition node needs exactly one edge per branch (matched on the edge's
        ``sourceAnchorId``). Any other node may have at most one outgoing edge, and may
        have none only if its type is in ``end_nodes``.
        @raise AppApiException: when a connection rule is violated
        """
        self.is_valid_node_params(node)
        step_name = node.properties.get("stepName")
        if node.type == 'condition-node':
            branch_list = node.properties.get('node_data').get('branch')
            for branch in branch_list:
                source_anchor_id = f"{node.id}_{branch.get('id')}_right"
                edge_list = [edge for edge in self.edges if edge.sourceAnchorId == source_anchor_id]
                if not edge_list:
                    raise AppApiException(500, f'{step_name} 节点的{branch.get("type")}分支需要连接')
                if len(edge_list) > 1:
                    raise AppApiException(500, f'{step_name} 节点的{branch.get("type")}分支不能连接俩个节点')
        else:
            edge_list = [edge for edge in self.edges if edge.sourceNodeId == node.id]
            if not edge_list and node.type not in end_nodes:
                raise AppApiException(500, f'{step_name} 节点不能当做结束节点')
            if len(edge_list) > 1:
                raise AppApiException(500, f'{step_name} 节点不能连接俩个节点')

    def get_next_nodes(self, node: Node):
        """
        Return the target nodes of every edge leaving ``node``, in edge order.

        @raise AppApiException: when no target node exists and ``node`` is not an end node
        """
        target_ids = [edge.targetNodeId for edge in self.edges if edge.sourceNodeId == node.id]
        node_list = [candidate for target_id in target_ids
                     for candidate in self.nodes if candidate.id == target_id]
        if not node_list and node.type not in end_nodes:
            raise AppApiException(500, '不存在的下一个节点')
        return node_list

    def is_valid_work_flow(self, up_node=None):
        """
        Recursively validate ``up_node`` (the start node by default) and every node
        reachable from it.

        NOTE(review): there is no visited set, so a cyclic graph ends in RecursionError
        and nodes reachable through several paths are validated more than once.
        """
        if up_node is None:
            up_node = self.get_start_node()
        self.is_valid_node(up_node)
        for next_node in self.get_next_nodes(up_node):
            self.is_valid_work_flow(next_node)

    def is_valid_start_node(self):
        """Require exactly one node whose id is 'start-node'."""
        start_node_list = [node for node in self.nodes if node.id == 'start-node']
        if len(start_node_list) == 0:
            raise AppApiException(500, '开始节点必填')
        if len(start_node_list) > 1:
            raise AppApiException(500, '开始节点只能有一个')

    def is_valid_model_params(self):
        """
        Validate the model parameter settings of every 'ai-chat-node' and 'question-node'.

        When ``model_params_setting`` is missing, it is filled with the model's default
        form data (written back into the node's ``node_data``) before validation.
        @raise AppApiException: when the node's ``model_id`` does not match any model
        """
        node_list = [node for node in self.nodes if node.type in ('ai-chat-node', 'question-node')]
        for node in node_list:
            node_data = node.properties.get('node_data', {})
            model = QuerySet(Model).filter(id=node_data.get('model_id')).first()
            if model is None:
                # first() returns None for an unknown/unset model_id; fail with a
                # user-facing error instead of an AttributeError on None below.
                raise AppApiException(500, f'{node.properties.get("stepName")} 节点的模型不存在')
            credential = get_model_credential(model.provider, model.model_type, model.model_name)
            model_params_setting_form = credential.get_model_params_setting_form(model.model_name)
            model_params_setting = node_data.get('model_params_setting')
            if model_params_setting is None:
                model_params_setting = model_params_setting_form.get_default_form_data()
                node_data['model_params_setting'] = model_params_setting
            model_params_setting_form.valid_form(model_params_setting)

    def is_valid_base_node(self):
        """Require exactly one node whose id is 'base-node'."""
        base_node_list = [node for node in self.nodes if node.id == 'base-node']
        if len(base_node_list) == 0:
            raise AppApiException(500, '基本信息节点必填')
        if len(base_node_list) > 1:
            raise AppApiException(500, '基本信息节点只能有一个')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class WorkflowManage:
    """
    Executes a Flow node by node.

    Keeps the workflow-level ``context`` (global variables), the executed node
    instances in ``node_context`` and the accumulated ``answer``. Returns either a
    blocking response or a streaming response.
    """

    def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler):
        self.params = params
        self.flow = flow
        # Workflow-level variables; exposed as 'global' by get_reference_field / generate_prompt.
        self.context = {}
        # Node instances in the order they were executed.
        self.node_context = []
        self.work_flow_post_handler = work_flow_post_handler
        self.current_node = None
        self.current_result = None
        # Answer text handed to the response / post handler; presumably appended to by
        # nodes through write_context(node, self) — confirm in the node implementations.
        self.answer = ""

    def run(self):
        """Run the workflow; stream when ``params['stream']`` is truthy."""
        if self.params.get('stream'):
            return self.run_stream()
        return self.run_block()

    def run_block(self):
        """
        Execute every node synchronously and return a single (non-streaming) response.

        If a node fails, the exception text is returned as the answer together with the
        failing node's error context. If the failure happens before any node exists
        (e.g. the flow has no start node), the original exception is re-raised.
        """
        try:
            while self.has_next_node(self.current_result):
                self.current_node = self.get_next_node()
                self.node_context.append(self.current_node)
                self.current_result = self.current_node.run()
                result = self.current_result.write_context(self.current_node, self)
                if result is not None:
                    # write_context may return a lazy iterable; consume it so it runs fully.
                    list(result)
                if not self.has_next_node(self.current_result):
                    return tools.to_response_simple(self.params['chat_id'], self.params['chat_record_id'],
                                                    AIMessage(self.answer), self,
                                                    self.work_flow_post_handler)
        except Exception as e:
            if self.current_node is None:
                # Without a node there is no error context to attach; re-raise the real
                # cause rather than failing with an AttributeError on None.
                raise
            return tools.to_response(self.params['chat_id'], self.params['chat_record_id'],
                                     AIMessage(str(e)), self, self.current_node.get_write_error_context(e),
                                     self.work_flow_post_handler)

    def run_stream(self):
        """Return a streaming response driven by stream_event()."""
        return tools.to_stream_response_simple(self.stream_event())

    def stream_event(self):
        """
        Generator behind run_stream: executes the nodes and yields 'data: ...' chunks.

        Chunks are forwarded only while is_result() is true for the current node. A final
        chunk with ``is_end=True`` closes the stream. On failure, the exception text is
        appended to the answer, the post handler still runs, and the error is yielded as
        the final chunk.
        """
        try:
            while self.has_next_node(self.current_result):
                self.current_node = self.get_next_node()
                self.node_context.append(self.current_node)
                self.current_result = self.current_node.run()
                result = self.current_result.write_context(self.current_node, self)
                if result is not None:
                    for chunk in result:
                        if self.is_result():
                            yield self.get_chunk_content(chunk)
                if not self.has_next_node(self.current_result):
                    yield self.get_chunk_content('', True)
                    break
            self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
                                                self.answer,
                                                self)
        except Exception as e:
            if self.current_node is not None:
                # Called for its side effects on the node; the return value is unused here.
                self.current_node.get_write_error_context(e)
            self.answer += str(e)
            self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
                                                self.answer,
                                                self)
            yield self.get_chunk_content(str(e), True)

    def is_result(self):
        """
        Whether the current node's output should be returned to the user.

        Uses ``node_params['is_result']`` when present; otherwise this is true only when
        there is no next node. Always false when the node has no node_params.
        """
        return self.current_node.node_params.get('is_result', not self.has_next_node(
            self.current_result)) if self.current_node.node_params is not None else False

    def get_chunk_content(self, chunk, is_end=False):
        """Serialize one stream chunk as a 'data: <json>' line followed by a blank line."""
        return 'data: ' + json.dumps(
            {'chat_id': self.params['chat_id'], 'id': self.params['chat_record_id'], 'operate': True,
             'content': chunk, 'is_end': is_end}, ensure_ascii=False) + "\n\n"

    def has_next_node(self, node_result: NodeResult | None):
        """
        Whether there is another node to run.

        Before any node has run, this asks whether a start node exists. After a branching
        (assertion) result, only the edge leaving the selected branch's anchor counts.
        Otherwise, any edge leaving the current node counts.
        """
        if self.current_node is None:
            return self.get_start_node() is not None
        current_id = self.current_node.id
        if node_result is not None and node_result.is_assertion_result():
            anchor_id = f"{current_id}_{node_result.node_variable.get('branch_id')}_right"
            return any(edge.sourceNodeId == current_id and edge.sourceAnchorId == anchor_id
                       for edge in self.flow.edges)
        return any(edge.sourceNodeId == current_id for edge in self.flow.edges)

    def get_runtime_details(self):
        """Map each executed node id to ``node.get_details(index)`` (index = execution order)."""
        return {node.id: node.get_details(index) for index, node in enumerate(self.node_context)}

    def get_next_node(self):
        """
        Instantiate the next node to run, or return None when there is none.

        Before any node has run this is the start node. After a branching (assertion)
        result, it is the target of the selected branch's edge; otherwise it is the
        target of the first edge leaving the current node.
        """
        if self.current_node is None:
            node = self.get_start_node()
            # Pass the manager itself as the node's workflow manager, exactly as
            # get_node_cls_by_id does for every other node (not the context dict).
            return get_node(node.type)(node, self.params, self)
        current_id = self.current_node.id
        if self.current_result is not None and self.current_result.is_assertion_result():
            anchor_id = f"{current_id}_{self.current_result.node_variable.get('branch_id')}_right"
            for edge in self.flow.edges:
                if edge.sourceNodeId == current_id and edge.sourceAnchorId == anchor_id:
                    return self.get_node_cls_by_id(edge.targetNodeId)
        else:
            for edge in self.flow.edges:
                if edge.sourceNodeId == current_id:
                    return self.get_node_cls_by_id(edge.targetNodeId)
        return None

    def get_reference_field(self, node_id: str, fields: List[str]):
        """
        Resolve a referenced value.

        @param node_id: node id, or 'global' for the workflow-level context
        @param fields: field path to read
        @return: the referenced value
        """
        if node_id == 'global':
            return INode.get_field(self.context, fields)
        return self.get_node_by_id(node_id).get_reference_field(fields)

    def generate_prompt(self, prompt: str):
        """
        Format a prompt template.

        Labels ``<stepName>.<field>`` and ``全局变量.<field>`` ("global variable") are
        first rewritten to ``context['<node id>'].<field>`` / ``context['global'].<field>``.
        The text is then rendered as a jinja2 template against the global context and
        the context of every executed node.

        @param prompt: prompt text
        @return: the formatted prompt
        """
        context = {
            'global': self.context,
        }
        for node in self.node_context:
            properties = node.node.properties
            node_config = properties.get('config')
            if node_config is not None:
                fields = node_config.get('fields')
                if fields is not None:
                    for field in fields:
                        label = f"{properties.get('stepName')}.{field.get('value')}"
                        expr = f"context['{node.id}'].{field.get('value')}"
                        prompt = prompt.replace(label, expr)
                global_fields = node_config.get('globalFields')
                if global_fields is not None:
                    for field in global_fields:
                        label = f"全局变量.{field.get('value')}"
                        expr = f"context['global'].{field.get('value')}"
                        prompt = prompt.replace(label, expr)
            context[node.id] = node.context
        # NOTE(review): the prompt text comes from the workflow definition and is rendered
        # as a jinja2 template; confirm the installed langchain_core sandboxes jinja2.
        prompt_template = PromptTemplate.from_template(prompt, template_format='jinja2')
        return prompt_template.format(context=context)

    def get_start_node(self):
        """
        Return the first node of type 'start-node'.

        @return: the start node (IndexError when the flow has none)
        """
        start_node_list = [node for node in self.flow.nodes if node.type == 'start-node']
        return start_node_list[0]

    def get_node_cls_by_id(self, node_id):
        """Instantiate the flow node with id ``node_id``; None when no such node exists."""
        for node in self.flow.nodes:
            if node.id == node_id:
                return get_node(node.type)(node, self.params, self)
        return None

    def get_node_by_id(self, node_id):
        """Return the already-executed node instance with id ``node_id``, or None."""
        for node in self.node_context:
            if node.id == node_id:
                return node
        return None

    def get_node_reference(self, reference_address: Dict):
        """
        Read a field from an executed node's context.

        @param reference_address: dict with 'node_id' and 'node_field' keys
        """
        node = self.get_node_by_id(reference_address.get('node_id'))
        return node.context[reference_address.get('node_field')]
|