Browse Source

feat: 数据模型添加AI创建组件

liaojiaxing 1 month ago
parent
commit
080a4f100b

+ 2 - 2
apps/designer/src/hooks/useChat.ts

@@ -236,7 +236,7 @@ export function useChat({ app_name, onStart, onSuccess, onUpdate, onError }: Cha
     onUpdate: (message: ResponseMessageItem) => void;
     onSuccess: (message: ResponseMessageItem) => void;
     onError: (error: Error) => void;
-  }) => {
+  }, name?: string) => {
     setConversationList((list) => {
       return list?.map((item) => {
         return {
@@ -249,7 +249,7 @@ export function useChat({ app_name, onStart, onSuccess, onUpdate, onError }: Cha
       {
         app_name,
         chat_query,
-        chat_name: activeConversation === "1" ? chat_query : undefined,
+        chat_name: activeConversation === "1" ? name || chat_query : undefined,
         conversation_id:
           activeConversation === "1" ? undefined : activeConversation,
       },

+ 42 - 18
apps/designer/src/pages/flow/components/Config/AiCreator.tsx

@@ -9,6 +9,7 @@ import {
 import { Button, Tooltip, Input, Form, Dropdown, message } from "antd";
 import type { DropDownProps } from "antd";
 import { useChat } from "@/hooks/useChat";
+import {  } from "react-markdown"
 
 
 interface ChatHistoryItem {
@@ -36,30 +37,20 @@ export default function AICreator(props: {
   const [input, setInput] = useState("");
   const [messageApi, contextHolder] = message.useMessage();
   const messageKey = 'ailoading';
-  const createContent = useRef<string>('');
+  const msgContent = useRef<string>('');
 
   const {
     loading,
     onRequest,
     cancel
   } = useChat({
-    app_name: "system_design",
-    onStart: () => {
-      messageApi.open({
-        key: messageKey,
-        type: 'loading',
-        content: "AI创作中...",
-        duration: 0,
-        style: {
-          marginTop: 300
-        }
-      })
-    },
+    app_name: "data_model",
     onUpdate: (msg) => {
-      createContent.current += msg.answer;
+      setInput("");
+      msgContent.current += msg.answer;
     },
     onSuccess: (msg) => {
-      console.log('加载完毕!', createContent.current);
+      console.log('加载完毕!', msgContent.current);
       messageApi.open({
         key: messageKey,
         type: 'success',
@@ -84,9 +75,33 @@ export default function AICreator(props: {
     }
   });
 
+  function regexExtractJSON(markdown: string) {
+    const jsonRegex = /```(?:json)?\n([\s\S]*?)\n```/g;
+    const matches = [];
+    let match;
+    
+    while ((match = jsonRegex.exec(markdown)) !== null) {
+        try {
+            const jsonObj = JSON.parse(match[1]);
+            matches.push(jsonObj);
+        } catch (e) {
+            console.warn('无效JSON:', match[0]);
+        }
+    }
+    return matches;
+}
+
   const handleParse = () => {
     try {
-      const json = JSON.parse(createContent.current);
+      // 根据markdown格式取出json部分数据
+      const md = msgContent.current;
+      let json: string;
+      if(md.includes('```json')) {
+        json = regexExtractJSON(msgContent.current)?.[0];
+      } else {
+        json = JSON.parse(msgContent.current);
+      }
+      
       console.log('解析结果:', json);
       props.onChange?.(json);
     } catch (error) {
@@ -130,8 +145,17 @@ export default function AICreator(props: {
   // 处理提交
   const onSubmit = () => {
     if (input.trim()) {
-      onRequest(`设计一个${graphType}, 返回图形json数据, 具体需求描述:${input}`);
-      setInput("");
+      onRequest(`设计一个${graphType}, 返回图形json数据, 具体需求描述:${input}`, undefined, input);
+
+      messageApi.open({
+        key: messageKey,
+        type: 'loading',
+        content: "AI创作中...",
+        duration: 0,
+        style: {
+          marginTop: 300
+        }
+      })
     }
   };
 

+ 3 - 2
apps/er-designer/.umirc.ts

@@ -17,13 +17,14 @@ export default defineConfig({
   ],
   proxy: {
     '/api': {
-      'target': 'http://ab.dev.jbpm.shalu.com/',
+      // 'target': 'http://ab.dev.jbpm.shalu.com/',
+      'target': 'https://edesign.shalu.com/',
       'changeOrigin': true,
       'pathRewrite': { '^/api' : '' },
     },
   },
   scripts: [
-    '//at.alicdn.com/t/c/font_4767192_twj930g7m9m.js'
+    '//at.alicdn.com/t/c/font_4767192_w9at5kfn7o.js'
   ],
   plugins: [
     require.resolve('@umijs/plugins/dist/unocss'),

+ 62 - 0
apps/er-designer/src/api/ai.ts

@@ -0,0 +1,62 @@
+import { request } from "umi";
+
+/**
+ * 获取会话列表
+ * @param app_name 应用名称
+ * @param page_index 页码
+ */
+export const GetSessionList = (params: {
+  app_name: string;
+  page_index: number;
+}) =>
+  request("/api/ai/chat-session/list", {
+    method: "get",
+    params,
+  });
+
+/**
+ * 获取会话消息列表
+ * @param app_name 应用名称
+ * @param session_id 会话id
+ * @param page_index 页码
+ */
+export const GetSessionMessageList = (params: {
+  app_name: string;
+  session_id: string;
+  page_index: number;
+}) =>
+  request("/api/ai/chat-message/list", {
+    method: "get",
+    params,
+  });
+
+
+/**
+ * 修改会话名称
+ * @param app_name 应用名称
+ * @param session_id 会话id
+ * @param page_index 页码
+ */
+export const ChangeSessionName = (data: {
+  app_name: string;
+  session_id: string;
+  new_name: string;
+}) =>
+  request("/api/ai/chat-session/rename", {
+    method: "post",
+    data,
+  });
+
+  /**
+ * 删除会话
+ * @param app_name 应用名称
+ * @param session_id 会话id
+ */
+export const DeleteSession = (data: {
+  app_name: string;
+  session_id: string;
+}) =>
+  request("/api/ai/chat-session/delete", {
+    method: "post",
+    data,
+  });

+ 11 - 0
apps/er-designer/src/api/dataModel.ts

@@ -63,4 +63,15 @@ export const PushDataModelTable = (data: any) => {
     method: "POST",
     data,
   });
+}
+
+/**
+ * 批量添加AI创建结果
+ * @param data 
+ */
+export const BatchAddAICreateResult = (data: any) => {
+  return request("/api/erDiagram/dataModel/doBatchCreateDataModel", {
+    method: "POST",
+    data,
+  });
 }

BIN
apps/er-designer/src/assets/icon-ai-3.png


+ 4 - 0
apps/er-designer/src/global.less

@@ -3,4 +3,8 @@
   vertical-align: -0.15em;
   fill: currentColor;
   overflow: hidden;
+}
+
+.ai-modal-wrapper {
+  pointer-events: none;
 }

+ 325 - 0
apps/er-designer/src/hooks/useChat.ts

@@ -0,0 +1,325 @@
+import { useXAgent, XStream } from "@ant-design/x";
+import { useEffect, useRef, useState } from "react";
+import { useSessionStorageState } from "ahooks";
+import { GetSessionList, GetSessionMessageList } from "@/api/ai";
+
+import type { ConversationsProps } from "@ant-design/x";
+import type { ReactNode } from "react";
+
+// 消息格式
+type MessageItem = {
+  id: string;
+  content: string | ReactNode;
+  role: "user" | "assistant" | "system";
+  status: "loading" | "done" | "error" | "stop";
+  loading?: boolean;
+  footer?: ReactNode;
+};
+
+// 后端返回格式
+type ResponseMessageItem = {
+  answer: string;
+  conversation_id: string;
+  created_at: number;
+  event: "message" | "message_end" | "message_error" | "ping";
+  message_id: string;
+  task_id: string;
+};
+
+type ChatParams = {
+  // 应用名称
+  app_name: string;
+  // 会话内容
+  chat_query: string;
+  // 会话名称 第一次
+  chat_name?: string;
+  // 会话id 后续会话带入
+  conversation_id?: string;
+};
+
+type ChatProps = {
+  // 应用名称
+  app_name: string;
+  // 会话id 后续会话带入
+  conversation_id?: string;
+  // 开始流式传输内容
+  onStart?: (data?: ResponseMessageItem) => void;
+  // 成功获取会话内容
+  onSuccess?: (data: ResponseMessageItem) => void;
+  // 更新流式消息内容
+  onUpdate: (data: ResponseMessageItem) => void;
+  // 异常
+  onError?: (error: Error) => void;
+};
+
+const defaultConversation = {
+  // 会话id
+  key: "1",
+  label: "新的对话",
+};
+
+export function useChat({ app_name, onStart, onSuccess, onUpdate, onError }: ChatProps) {
+  /**
+   * 发送消息加载状态
+   */
+  const [loading, setLoading] = useState(false);
+
+  /**
+   * 加载会话记录列表
+   */
+  const [loadingSession, setLoadingSession] = useState(false);
+
+  /**
+   * 加载消息列表
+   */
+  const [loadingMessages, setLoadingMessages] = useState(false);
+
+  // 用于停止对话
+  const abortController = useRef<AbortController | null>(null);
+
+  /**
+   * 消息列表
+   */
+  const [messages, setMessages] = useState<Array<MessageItem>>([]);
+
+  // 会话列表
+  const [conversationList, setConversationList] = useState<
+    ConversationsProps["items"]
+  >([{ ...defaultConversation }]);
+
+  // 活动对话
+  const [activeConversation, setActiveConversation] = useState("1");
+
+  // 当前智能体对象
+  const [currentAgent, setCurrentAgent] = useSessionStorageState("agent-map");
+
+  useEffect(() => {
+    setLoadingSession(true);
+    GetSessionList({
+      app_name,
+      page_index: 1,
+    })
+      .then((res) => {
+        setConversationList([
+          { ...defaultConversation },
+          ...(res?.result?.model || []).map((item: any) => ({
+            ...item,
+            key: item.sessionId,
+            label: item.name,
+          })),
+        ]);
+      })
+      .finally(() => {
+        setLoadingSession(false);
+      });
+  }, [app_name]);
+
+  /**
+   * 切换会话
+   * @param key 会话id
+   * @returns 
+   */
+  const changeConversation = async (key: string) => {
+    cancel();
+    setActiveConversation(key);
+    if (key === "1") {
+      setMessages([]);
+      return;
+    }
+    setLoadingMessages(true);
+    // 获取会话内容
+    try {
+      const res = await GetSessionMessageList({
+        app_name,
+        session_id: key,
+        page_index: 1,
+      });
+
+      const list: MessageItem[] = [];
+      (res?.result?.model || []).forEach((item: any) => {
+        list.push(
+          {
+            id: item.id + "_query",
+            content: item.query,
+            role: "user",
+            status: "done",
+          },
+          {
+            id: item.id + "_query",
+            content: item.answer,
+            role: "assistant",
+            status: "done",
+          }
+        );
+      });
+      setMessages(list);
+    } finally {
+      setLoadingMessages(false);
+    }
+  };
+
+  const baseUrl = process.env.NODE_ENV === "production" ? "" : "/api";
+
+  /**
+   * 封装智能体
+   */
+  const [agent] = useXAgent<ResponseMessageItem>({
+    request: async (message, { onError, onSuccess, onUpdate }) => {
+      const enterpriseCode = sessionStorage.getItem("enterpriseCode");
+      const token = localStorage.getItem("token_" + enterpriseCode) || '';
+
+      abortController.current = new AbortController();
+      const signal = abortController.current.signal;
+      try {
+        setLoading(true);
+        const response = await fetch(
+          baseUrl + "/api/ai/chat-message",
+          {
+            method: "POST",
+            body: JSON.stringify(message),
+            headers: {
+              Authorization: token,
+              "Content-Type": "application/json",
+            },
+            signal,
+          }
+        );
+
+        // 判断当前是否流式返回
+        if(response.headers.get('content-type')?.includes('text/event-stream')) {
+          if (response.body) {
+            for await (const chunk of XStream({
+              readableStream: response.body,
+            })) {
+              const data = JSON.parse(chunk.data);
+              if (data?.event === "message") {
+                onUpdate(data);
+              }
+              if (data?.event === "message_end") {
+                onSuccess(data);
+              }
+              if (data?.event === "message_error") {
+                onError(data);
+              }
+              if (data?.event === "ping") {
+                console.log(">>>> stream start <<<<");
+                onStart?.(data);
+              }
+            }
+          }
+        } else {
+          // 接口异常处理
+          response.json().then(res => {
+            if(res.code === 0 ) {
+              onError?.(Error(res?.error || '请求失败'));
+              cancel();
+            }
+          });
+        }
+      } catch (error) {
+        // 判断是不是 abort 错误
+        if (signal.aborted) {
+          return;
+        }
+        onError(error as Error);
+      } finally {
+        setLoading(false);
+      }
+    },
+  });
+
+  /**
+   * 发起请求
+   * @param chat_query 对话内容
+   */
+  const onRequest = (chat_query: string, callbacks?: {
+    onUpdate: (message: ResponseMessageItem) => void;
+    onSuccess: (message: ResponseMessageItem) => void;
+    onError: (error: Error) => void;
+  }, name?: string) => {
+    setConversationList((list) => {
+      return list?.map((item) => {
+        return {
+          ...item,
+          label: item.key === "1" ? chat_query : item.label,
+        };
+      });
+    });
+    agent.request(
+      {
+        app_name,
+        chat_query,
+        chat_name: activeConversation === "1" ? name || chat_query : undefined,
+        conversation_id:
+          activeConversation === "1" ? undefined : activeConversation,
+      },
+      callbacks ?? {
+        onSuccess: (data) => {
+          onSuccess?.(data);
+        },
+        onUpdate: (data) => {
+          onUpdate(data);
+          // 更新会话相关信息
+          if (activeConversation === "1") {
+            setConversationList((list) => {
+              return list?.map((item) => {
+                return {
+                  ...item,
+                  // 更新当前会话id
+                  key: item.key === "1" ? data.conversation_id : item.key,
+                };
+              });
+            });
+            setActiveConversation(data.conversation_id);
+          }
+        },
+        onError: (error) => {
+          console.log("error", error);
+          onError?.(error);
+        },
+      }
+    );
+  };
+
+  /**
+   * 停止对话
+   */
+  const cancel = () => {
+    abortController.current?.abort();
+  };
+
+  /**
+   * 新增会话
+   */
+  const addConversation = () => {
+    cancel();
+    setMessages([]);
+    setActiveConversation("1");
+    // 还没产生对话时 直接清除当前对话
+    if (!conversationList?.find((item) => item.key === "1")) {
+      setConversationList([
+        {
+          ...defaultConversation,
+        },
+        ...(conversationList || []),
+      ]);
+    }
+  };
+
+  return {
+    agent,
+    loading,
+    loadingMessages,
+    loadingSession,
+    cancel,
+    messages,
+    setMessages,
+    conversationList,
+    setConversationList,
+    activeConversation,
+    setActiveConversation,
+    onRequest,
+    addConversation,
+    changeConversation,
+  };
+}

+ 28 - 2
apps/er-designer/src/models/erModel.tsx

@@ -6,9 +6,9 @@ import { Snapline } from "@antv/x6-plugin-snapline";
 import { Keyboard } from "@antv/x6-plugin-keyboard";
 import { Export } from "@antv/x6-plugin-export";
 import { Selection } from "@antv/x6-plugin-selection";
-import { SaveDataModel, UploadFile } from "@/api";
+import { SaveDataModel, UploadFile, BatchAddAICreateResult } from "@/api";
 import { useFullscreen, useSessionStorageState, useLocalStorageState } from "ahooks";
-import { createTable } from "@/utils";
+import { createTable, createColumn } from "@/utils";
 import dayjs from "dayjs";
 import { getClassRules, base64ToFile, uuid } from "@repo/utils";
 
@@ -45,6 +45,8 @@ export default function erModel() {
       listenStorageChange: true,
     }
   );
+  // 更新画布标识
+  const [updateKey, setUpdateKey] = useSessionStorageState('update-key', { listenStorageChange: true, defaultValue: 0 });
   const [saveTime, setSaveTime] = useState<string>();
   const [project, setProjectInfo] = useState<ProjectInfo>({
     id: "",
@@ -990,6 +992,29 @@ export default function erModel() {
     });
   };
 
+  /* AI创作 返回结果
+    {
+      tables: [{table: {}, tableColumnList: []}],
+      relations: []
+    } 
+  */
+  const onCreateByAi = async (data: any) => {
+    console.log(data);
+    // if(data?.tables?.length) {
+    //   data.tables.forEach((tableItem: TableItemType) => {
+    //     const newTable = createTable(project.type || 3, project.id);
+    //     merge(newTable.table, tableItem.table);
+    //     table
+    //   })
+    // }
+    await BatchAddAICreateResult({
+      ...data,
+      dataModelId: project.id
+    });
+
+    setUpdateKey((state) => (state || 0) + 1);
+  }
+
   return {
     initGraph,
     graph,
@@ -1025,5 +1050,6 @@ export default function erModel() {
     onSave,
     tableActive,
     setTableActive,
+    onCreateByAi
   };
 }

+ 37 - 7
apps/er-designer/src/pages/detail/index.tsx

@@ -8,7 +8,7 @@ import {
   Tooltip,
   Empty,
   Spin,
-  Modal
+  Modal,
 } from "antd";
 import { ProDescriptions } from "@ant-design/pro-components";
 import type { DescriptionsProps, MenuProps } from "antd";
@@ -27,6 +27,7 @@ import LangInput from "@/components/LangInput";
 import LangInputTextarea from "@/components/LangInputTextarea";
 import { validateAliasName, validateTableCode } from "@/utils/validator";
 import SyncModal from "@/components/SyncModal";
+import AICreator from "@/pages/er/components/AICreator";
 
 const { Content, Header } = Layout;
 export default function index() {
@@ -40,6 +41,7 @@ export default function index() {
     exitPlayMode,
     updateTable,
     graph,
+    onCreateByAi
   } = useModel("erModel");
   const [searchKeyword, setSearchKeyword] = useState("");
   const [selectKey, setSelectKey] = useState<string>(
@@ -98,7 +100,7 @@ export default function index() {
     "er-hideDefaultColumn",
     {
       defaultValue: false,
-      listenStorageChange: true
+      listenStorageChange: true,
     }
   );
 
@@ -266,24 +268,29 @@ export default function index() {
   const extra = (
     <div className="flex gap-12px">
       <a onClick={handleSync}>
-        <i className="iconfont icon-tongbu text-12px" />
+        <i className="iconfont icon-tongbu text-12px mr-4px" />
         数据表同步
       </a>
       <a onClick={() => project.id && addModelRef.current?.edit(project)}>
-        <i className="iconfont icon-bianji text-12px" />
+        <i className="iconfont icon-bianji text-12px mr-4px" />
         基础信息
       </a>
       <a onClick={handleEnterEdit}>
-        <i className="iconfont icon-bianji text-12px" />
+        <i className="iconfont icon-bianji text-12px mr-4px" />
         模型编辑
       </a>
       <a>
-        <i className="iconfont icon-moban text-14px" />
+        <i className="iconfont icon-moban text-14px mr-4px" />
         保存为模板
       </a>
     </div>
   );
 
+  const handleAiCreate = async (data: any) => {
+    await onCreateByAi(data);
+    refresh();
+  }
+
   return (
     <Spin spinning={loading}>
       {/* 基础信息修改弹窗 */}
@@ -294,7 +301,7 @@ export default function index() {
         }}
       />
       {/* 同步弹窗 */}
-      <SyncModal ref={syncModalRef} onPush={refresh}/>
+      <SyncModal ref={syncModalRef} onPush={refresh} />
       <Layout className="h-100vh flex flex-col bg-#fafafa p-12px">
         <Header
           className="shadow-sm"
@@ -391,6 +398,29 @@ export default function index() {
                 >
                   实体
                 </Button>
+                <AICreator
+                  position={{
+                    bottom: 10,
+                    right: 10,
+                    top: "auto"
+                  }}
+                  onChange={handleAiCreate}
+                  trigger={
+                    <Button
+                      type="text"
+                      icon={
+                        <svg
+                          className="icon color-#666"
+                          aria-hidden="true"
+                        >
+                          <use xlinkHref="#icon-AI1"></use>
+                        </svg>
+                      }
+                    >
+                      AI助手
+                    </Button>
+                  }
+                />
               </div>
               <div className="right flex gap-8px m-b-12px">
                 {active === 0 ? (

+ 278 - 0
apps/er-designer/src/pages/er/components/AICreator.tsx

@@ -0,0 +1,278 @@
+import React, { useMemo, useRef, useState } from "react";
+import { Modal, message } from "antd";
+import type { DraggableData, DraggableEvent } from "react-draggable";
+import Draggable from "react-draggable";
+import { Sender, Welcome, Prompts } from "@ant-design/x";
+import { PromptsProps } from "@ant-design/x";
+import aiLogo from "@/assets/icon-ai-3.png";
+import { CoffeeOutlined, FireOutlined, SmileOutlined } from "@ant-design/icons";
+import { useChat } from "@/hooks/useChat";
+
+type AICteatorProps = {
+  trigger: JSX.Element;
+  onChange?: (data: any) => void;
+  onError?: (err: Error) => void;
+  position?: {
+    top?: number | string;
+    left?: number | string;
+    bottom?: number | string;
+    right?: number | string;
+  };
+};
+
+const items: PromptsProps["items"] = [
+  {
+    key: "6",
+    icon: <CoffeeOutlined style={{ color: "#964B00" }} />,
+    description: "帮我创建一个用户表",
+    disabled: false,
+  },
+  {
+    key: "7",
+    icon: <SmileOutlined style={{ color: "#FAAD14" }} />,
+    description: "创建一个订单表",
+    disabled: false,
+  },
+  // {
+  //   key: "8",
+  //   icon: <FireOutlined style={{ color: "#FF4D4F" }} />,
+  //   description: "创建一个商品表",
+  //   disabled: false,
+  // },
+];
+
+export default (props: AICteatorProps) => {
+  const [open, setOpen] = useState(false);
+  const [disabled, setDisabled] = useState(true);
+  const [bounds, setBounds] = useState({
+    left: 0,
+    top: 0,
+    bottom: 0,
+    right: 0,
+  });
+  const [input, setInput] = useState("");
+  const draggleRef = useRef<HTMLDivElement>(null!);
+  const [messageApi, contextHolder] = message.useMessage();
+  const msgContent = useRef<string>("");
+  const messageKey = "data-model";
+
+  function regexExtractJSON(markdown: string) {
+    const jsonRegex = /```(?:json)?\n([\s\S]*?)\n```/g;
+    const matches = [];
+    let match;
+
+    while ((match = jsonRegex.exec(markdown)) !== null) {
+      try {
+        const jsonObj = JSON.parse(match[1]);
+        matches.push(jsonObj);
+      } catch (e) {
+        console.warn("无效JSON:", match[0]);
+      }
+    }
+    return matches;
+  }
+
+  const handleParse = () => {
+    try {
+      // 根据markdown格式取出json部分数据
+      const md = msgContent.current;
+      let json: string;
+      if (md.includes("```json")) {
+        json = regexExtractJSON(msgContent.current)?.[0];
+      } else {
+        json = JSON.parse(msgContent.current);
+      }
+
+      console.log("解析结果:", json);
+      props.onChange?.(json);
+    } catch (error) {
+      messageApi.open({
+        key: messageKey,
+        type: "error",
+        content: "AI创作失败",
+        duration: 2,
+        style: {
+          marginTop: 300,
+        },
+      });
+      console.error(error);
+      props.onError?.(new Error("AI创作失败"));
+    }
+  };
+
+  const { loading, onRequest, cancel } = useChat({
+    app_name: "data_model",
+    onUpdate: (msg) => {
+      setInput("");
+      msgContent.current += msg.answer;
+    },
+    onSuccess: (msg) => {
+      console.log("加载完毕!", msgContent.current);
+      messageApi.open({
+        key: messageKey,
+        type: "success",
+        content: "AI创作完成",
+        duration: 2,
+        style: {
+          marginTop: 300,
+        },
+      });
+      handleParse();
+    },
+    onError: (err) => {
+      messageApi.open({
+        key: messageKey,
+        type: "error",
+        content: err.message || "AI创作失败",
+        duration: 2,
+        style: {
+          marginTop: 300,
+        },
+      });
+    },
+  });
+
+  const triggerDom = React.cloneElement(props.trigger, {
+    ...props.trigger.props,
+    onClick: () => {
+      setOpen(!open);
+    },
+  });
+
+  const onStart = (_event: DraggableEvent, uiData: DraggableData) => {
+    const { clientWidth, clientHeight } = window.document.documentElement;
+    const targetRect = draggleRef.current?.getBoundingClientRect();
+    if (!targetRect) {
+      return;
+    }
+    setBounds({
+      left: -targetRect.left + uiData.x,
+      right: clientWidth - (targetRect.right - uiData.x),
+      top: -targetRect.top + uiData.y,
+      bottom: clientHeight - (targetRect.bottom - uiData.y),
+    });
+  };
+
+  const onSubmit = (value: string) => {
+    if (value.trim()) {
+      const query = `设计一个数据模型内容,需求如下:${value.trim()}`;
+      onRequest(query, undefined, value);
+
+      messageApi.open({
+        key: messageKey,
+        type: "loading",
+        content: (
+          <span>
+            <svg className="icon mr-4px color-#666" aria-hidden="true">
+              <use xlinkHref="#icon-AI1"></use>
+            </svg>
+            <span>AI创作中...</span>
+          </span>
+        ),
+        duration: 0,
+        style: {
+          marginTop: 300,
+        },
+      });
+    }
+  };
+
+  const onStop = () => {
+    cancel();
+    msgContent.current = "";
+    messageApi.open({
+      key: messageKey,
+      type: "error",
+      content: "AI创作已取消",
+      duration: 2,
+      style: {
+        marginTop: 300,
+      },
+    });
+  };
+
+  const handlePromptClick = (item: any) => {
+    const msg = item.data.description || item.data.label;
+    onSubmit(msg);
+  };
+
+  const getStyle: React.CSSProperties = useMemo(
+    () =>
+      props.position
+        ? { position: "absolute", ...props.position }
+        : { position: "absolute", top: 114, right: 18 },
+    [props.position]
+  );
+
+  return (
+    <>
+      {triggerDom}
+      {contextHolder}
+      <Modal
+        title={
+          <div
+            style={{ width: "100%", cursor: "move" }}
+            onMouseOver={() => disabled && setDisabled(false)}
+            onMouseOut={() => setDisabled(true)}
+          >
+            <svg className="icon h-16px w-16px color-#666" aria-hidden="true">
+              <use xlinkHref="#icon-AI1"></use>
+            </svg>
+            <span className="ml-4px">AI助手</span>
+          </div>
+        }
+        mask={false}
+        maskClosable={false}
+        open={open}
+        width={440}
+        style={getStyle}
+        styles={{
+          content: {
+            backgroundImage:
+              "linear-gradient(137deg, #e5f4ff 0%, #efe7ff 100%)",
+          },
+          header: { background: "transparent" },
+        }}
+        footer={null}
+        destroyOnClose
+        wrapClassName="ai-modal-wrapper"
+        onCancel={() => setOpen(false)}
+        modalRender={(modal) => (
+          <Draggable
+            disabled={disabled}
+            bounds={bounds}
+            nodeRef={draggleRef}
+            onStart={(event, uiData) => onStart(event, uiData)}
+          >
+            <div ref={draggleRef}>{modal}</div>
+          </Draggable>
+        )}
+      >
+        <div className="my-10">
+          <Welcome
+            variant="borderless"
+            icon={<img src={aiLogo} className="rounded-lg" alt="AI Logo" />}
+            title="你好,我是数据模型AI助手"
+            description="你需要创建什么的内容,我可以帮你快速生成~"
+          />
+        </div>
+
+        <Prompts
+          className="mb-10"
+          items={items}
+          vertical
+          onItemClick={handlePromptClick}
+        />
+
+        <Sender
+          placeholder="如:创建一个用户表"
+          loading={loading}
+          value={input}
+          onChange={setInput}
+          onSubmit={onSubmit}
+          onCancel={onStop}
+        />
+      </Modal>
+    </>
+  );
+};

+ 19 - 0
apps/er-designer/src/pages/er/components/Toolbar.tsx

@@ -3,6 +3,7 @@ import { Button, Tooltip, Divider, Dropdown } from "antd";
 import { DownOutlined } from "@ant-design/icons";
 import { useModel } from "umi";
 import TodoDrawer from "./TodoDrawer";
+import AICreator from "./AICreator";
 export default function Toolbar() {
   const {
     addTable,
@@ -16,6 +17,7 @@ export default function Toolbar() {
     project,
     setProject,
     onSave,
+    onCreateByAi
   } = useModel("erModel");
   const todoRef = React.useRef<{ open: () => void }>();
   const scaleMenu = {
@@ -192,6 +194,23 @@ export default function Toolbar() {
 
         <div className="h-30px border-transparent border-r-1px border-solid border-r-#eee m-x-8px"></div>
 
+        <div className="group">
+          <div className="flex items-center">
+            <Tooltip title="AI助手">
+              <AICreator 
+                trigger={<div className="btn flex flex-col items-center cursor-pointer py-4px px-10px hover:bg-gray-200">
+                  <svg className="icon h-24px w-24px color-#666" aria-hidden="true">
+                    <use xlinkHref="#icon-AI1"></use>
+                  </svg>
+                </div>}
+                onChange={onCreateByAi}
+              />
+            </Tooltip>
+          </div>
+        </div>
+
+        <div className="h-30px border-transparent border-r-1px border-solid border-r-#eee m-x-8px"></div>
+
         <div className="group">
           <div className="flex items-center">
             <Tooltip title="放大">

+ 10 - 0
apps/er-designer/src/pages/er/index.tsx

@@ -31,6 +31,8 @@ const App: React.FC = () => {
       listenStorageChange: true,
     }
   );
+  // 更新画布key 用于刷新画布
+  const [updateKey, setUpdateKey] = useSessionStorageState('update-key', { listenStorageChange: true, defaultValue: 0 });
   const [show, setShow] = useSessionStorageState("show-navigator");
   const params = useParams();
 
@@ -57,6 +59,14 @@ const App: React.FC = () => {
     }
   }, []);
 
+  useEffect(() => {
+    if (params?.id && updateKey) {
+      run({
+        id: params.id,
+      });
+    }
+  }, [updateKey]);
+
   const tabItems: TabsProps["items"] = [
     {
       key: "1",