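A non-streaming response simply waits until Baidu has finished generating the answer and then returns it to you in one piece. A streaming response instead returns the answer while it is still being generated, which is the behavior most familiar from ChatGPT: the reply always appears one character at a time. Both modes have their place. In my view, if the answer you need is short (you can limit it through the prompt), or you can tolerate a long wait, a non-streaming response is perfectly fine.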
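Things change if you have constraints on the network request. For example, when the front end is written with Uniapp, uni.uploadFile has a default timeout of 10s, and it seems the timeout cannot be changed; at least I did not manage to change it.. That is not the key point though, haha. Once the connection is made, if the client has not received the server's message within the timeout, it refuses to accept it. Even if the answer is ready only a few tenths of a second too late, the client still rejects it, so in this situation a streaming response is the inevitable choice.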
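The timeout behavior described above can be reproduced without any network at all. The following is only a sketch: `ClientTimeoutSketch`, `generateAnswer` and `awaitAnswer` are hypothetical names, and a sleeping background task stands in for the model generating a complete answer.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

// Hypothetical sketch: a client that gives up after a fixed timeout.
class ClientTimeoutSketch {

    // Stands in for a server that needs generationMillis to produce the whole answer.
    static CompletableFuture<String> generateAnswer(long generationMillis) {
        return CompletableFuture.supplyAsync(() -> {
            try {
                Thread.sleep(generationMillis);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
            return "full answer";
        });
    }

    // Waits at most timeoutMillis; a late answer is dropped, however close it was.
    static String awaitAnswer(long generationMillis, long timeoutMillis) {
        try {
            return generateAnswer(generationMillis).get(timeoutMillis, TimeUnit.MILLISECONDS);
        } catch (TimeoutException e) {
            return "timeout";
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return "interrupted";
        } catch (ExecutionException e) {
            return "error";
        }
    }

    public static void main(String[] args) {
        System.out.println(awaitAnswer(50, 2000));  // the answer is ready before the deadline
        System.out.println(awaitAnswer(1500, 200)); // the answer is ready only after the deadline
    }
}
```

With streaming, the first chunk reaches the client long before the whole answer is finished, so a per-request deadline like this one is much less likely to be hit.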
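In this article the streamed answer is already filtered on the Java side; passing the stream through to the front end and processing it there might be better. Most implementations out there use SSE to maintain the whole conversation, but because Uniapp does not support it, I used WebSocket instead. The approach is broadly the same.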
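Filtering on the Java side mostly means stripping the framing around each streamed line before forwarding it. Here is a minimal sketch, assuming each event arrives as a line starting with `data: ` (the prefix the front-end code in this article removes) with blank lines in between. The class and method names are hypothetical, lines without the prefix are simply ignored, and parsing the JSON payload is left out.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch: extract the payloads from an SSE-style text stream.
class SseFramingSketch {
    private static final String DATA_PREFIX = "data: ";

    // Reads the stream line by line and returns the text after each "data: " prefix.
    static List<String> extractPayloads(BufferedReader reader) {
        List<String> payloads = new ArrayList<>();
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.isEmpty()) {
                    continue; // blank lines only separate events
                }
                if (line.startsWith(DATA_PREFIX)) {
                    payloads.add(line.substring(DATA_PREFIX.length()));
                }
                // Assumption: any other line is not needed and is skipped.
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return payloads;
    }

    public static void main(String[] args) {
        String stream = "data: {\"result\":\"Hel\"}\n\ndata: {\"result\":\"lo\"}\n\n";
        for (String payload : extractPayloads(new BufferedReader(new StringReader(stream)))) {
            System.out.println(payload);
        }
    }
}
```

In a real endpoint each payload would be forwarded as soon as it is read rather than collected into a list; the list only keeps the sketch easy to check.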
Dependencies:
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
<dependency>
    <groupId>com.squareup.okhttp3</groupId>
    <artifactId>okhttp</artifactId>
    <version>4.9.3</version>
</dependency>
Front end:
// Reconnect after the connection drops
reconnect() {
    if (this.ohHideFlag) {
        if (!this.is_open_socket) {
            this.reconnectTimeOut = setTimeout(() => {
                this.connectSocketInit();
            }, 3000)
        }
    }
},
connectSocketInit() {
    let token = getToken()
    this.socketTask = uni.connectSocket({
        // Use ws for http and wss for https; a mini program also needs the domain registered on its public platform
        url: 'wss://' + this.socketUrl + '/websocket/' + token,
        success: () => {
            console.log("Preparing to establish the WebSocket connection...");
            // Return the instance
            return this.socketTask
        },
    });
    this.socketTask.onOpen((res) => {
        console.log("WebSocket connected!");
        this.is_open_socket = true;
        this.socketTask.onMessage((res) => {
            let jsonString = res.data
            const dataPrefix = "data: ";
            if (jsonString.startsWith(dataPrefix)) {
                jsonString = jsonString.substring(dataPrefix.length);
            }
            // Parse the JSON string
            const jsonObject = JSON.parse(jsonString);
            // Read the result property
            const result = jsonObject.result;
            // An empty result is treated as the end of the answer
            if (result == "") {
                console.log("Answer finished")
                return;
            }
            console.log(result);
            this.tempItem.content += result
            this.scrollToBottom();
        });
    })
    this.socketTask.onClose(() => {
        console.log("Connection closed")
        this.is_open_socket = false;
        this.reconnect();
    })
},
Back-end code:
package com.farm.controller;
import com.farm.chat.StreamChat;
import lombok.extern.slf4j.Slf4j;
import okhttp3.ResponseBody;
import org.json.JSONException;
import org.json.JSONObject;
import org.springframework.stereotype.Component;
import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CopyOnWriteArraySet;
@Slf4j
@Component
@ServerEndpoint("/websocket/{target}") // Request path of the WebSocket endpoint.
public class WebsocketServerEndpoint {
    private Session session;
    private String target;
    // Supports continuous stream pushing
    private InputStream inputStream;
    private final static CopyOnWriteArraySet<WebsocketServerEndpoint> websockets = new CopyOnWriteArraySet<>();

    @OnOpen
    public void onOpen(Session session, @PathParam("target") String target) {
        this.session = session;
        this.target = target;
        websockets.add(this);
        log.info("websocket connect server success , target is {},total is {}", target, websockets.size());
    }
    // Triggered whenever the client sends a message
    @OnMessage
    public void onMessage(String message) throws IOException, JSONException {
        log.info("message is {}", message);
        JSONObject jsonObject = new JSONObject(message);
        String user = (String) jsonObject.get("user");
        String question = (String) jsonObject.get("message");
        StreamChat streamChat = new StreamChat();
        ResponseBody body = streamChat.getAnswerStream(question);
        // getAnswerStream returns null when the request did not succeed
        if (body == null) {
            log.warn("no answer stream available for user {}", user);
            return;
        }
        InputStream inputStream = body.byteStream();
        sendMessageSync(user, inputStream);
    }

    @OnClose
    public void onClose() {
        log.info("connection has been closed ,target is {},total is {}", this.target, websockets.size());
        this.destroy();
    }

    @OnError
    public void onError(Throwable throwable) {
        this.destroy();
        log.error("websocket connect error , target is {} ,total is {}, error is {}", this.target, websockets.size(), throwable.getMessage());
    }
    /**
     * Push a single message to the given target.
     * @param target identifier of the target connection
     * @param message text to send
     * @throws IOException if sending fails
     */
    public void sendMessageOnce(String target, String message) throws IOException {
        this.sendMessage(target, message, false, null);
    }

    /**
     * Push the content of a stream to the front end over WebSocket as it arrives.
     * @param target identifier of the target connection
     * @param is input stream to forward
     * @throws IOException
     */
    private void sendMessageSync(String target, InputStream is) throws IOException {
        WebsocketServerEndpoint websocket = getWebsocket(target);
        if (Objects.isNull(websocket)) {
            throw new RuntimeException("The websocket does not exist or has been closed.");
        }
        if (Objects.isNull(is)) {
            throw new RuntimeException("InputStream cannot be null.");
        } else {
            websocket.inputStream = is;
            CompletableFuture.runAsync(websocket::sendMessageWithInputSteam);
        }
    }

    /**
     * Send message.
     * @param target used to look up the {@link WebsocketServerEndpoint}.
     * @param message message
     * @param continuous whether to keep pushing messages from the inputStream.
     * @param is input stream
     * @throws IOException
     */
    private void sendMessage(String target, String message, Boolean continuous, InputStream is) throws IOException {
        WebsocketServerEndpoint websocket = getWebsocket(target);
        if (Objects.isNull(websocket)) {
            throw new RuntimeException("The websocket does not exist or has been closed.");
        }
        if (continuous) {
            if (Objects.isNull(is)) {
                throw new RuntimeException("InputStream can not be null when continuous is true.");
            } else {
                websocket.inputStream = is;
                CompletableFuture.runAsync(websocket::sendMessageWithInputSteam);
            }
        } else {
            websocket.session.getBasicRemote().sendText(message);
        }
    }
    /**
     * Keep pushing messages read from the inputStream.
     * Works for files, messages, logs and so on.
     */
    private void sendMessageWithInputSteam() {
        String message;
        // Decode explicitly as UTF-8 so the text does not depend on the platform default charset
        BufferedReader bufferedReader = new BufferedReader(
                new InputStreamReader(this.inputStream, java.nio.charset.StandardCharsets.UTF_8));
        try {
            while ((message = bufferedReader.readLine()) != null) {
                // Skip the blank lines between events
                if (message.isEmpty()) {
                    continue;
                }
                if (websockets.contains(this)) {
                    log.debug(message);
                    this.session.getBasicRemote().sendText(message);
                }
            }
        } catch (IOException e) {
            log.warn("SendMessage failed {}", e.getMessage());
        } finally {
            this.closeInputStream();
        }
    }

    /**
     * Look up the {@link WebsocketServerEndpoint} for the given target.
     * @param target the agreed identifier of the connection
     * @return WebsocketServerEndpoint, or null if none matches
     */
    private WebsocketServerEndpoint getWebsocket(String target) {
        WebsocketServerEndpoint websocket = null;
        for (WebsocketServerEndpoint ws : websockets) {
            if (target.equals(ws.target)) {
                websocket = ws;
            }
        }
        return websocket;
    }

    private void closeInputStream() {
        if (Objects.nonNull(inputStream)) {
            try {
                inputStream.close();
            } catch (Exception e) {
                log.warn("websocket close failed {}", e.getMessage());
            }
        }
    }

    private void destroy() {
        websockets.remove(this);
        this.closeInputStream();
    }
}
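The endpoint above finds a connection by scanning the whole set for a matching target. An alternative, sketched here under the assumption that each target has at most one live connection, is to key the connections by target in a `ConcurrentHashMap`. `TargetRegistrySketch` and its methods are hypothetical names and are not part of the code in this article.

```java
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical sketch: a thread-safe lookup table from target to connection.
class TargetRegistrySketch<T> {
    private final Map<String, T> byTarget = new ConcurrentHashMap<>();

    // Registers a connection and returns the one it replaced, or null if there was none.
    T register(String target, T connection) {
        return byTarget.put(target, connection);
    }

    // Removes the entry only if it still points at this connection,
    // so a stale close event cannot evict a newer connection.
    boolean unregister(String target, T connection) {
        return byTarget.remove(target, connection);
    }

    // Looks up the connection for a target without scanning every entry.
    Optional<T> find(String target) {
        return target == null ? Optional.empty() : Optional.ofNullable(byTarget.get(target));
    }

    int size() {
        return byTarget.size();
    }

    public static void main(String[] args) {
        TargetRegistrySketch<String> registry = new TargetRegistrySketch<>();
        registry.register("alice", "session-1");
        System.out.println(registry.find("alice").isPresent());
        System.out.println(registry.find("bob").isPresent());
    }
}
```

Whether this fits depends on how reconnects should behave: with the map, a new connection for the same target replaces the old entry, while the set keeps both until the old one closes.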
StreamChat
package com.farm.chat;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.springframework.web.bind.annotation.GetMapping;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
@Slf4j
public class StreamChat {
    // Conversation history; entries must follow the user, assistant order
    List<Map<String,String>> messages = new ArrayList<>();
    private final String ACCESS_TOKEN_URI = "https://aip.baidubce.com/oauth/2.0/token";
    private final String CHAT_URI = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-preview";
    // Fill in your own credentials here
    private String apiKey = " ";
    private String secretKey = " ";
    // Read timeout; note that getAccessToken() passes it with TimeUnit.SECONDS
    private int responseTimeOut = 5000;
    private OkHttpClient client;
    private String accessToken = "";

    public boolean getAccessToken() {
        this.client = new OkHttpClient.Builder().readTimeout(responseTimeOut, TimeUnit.SECONDS).build();
        MediaType mediaType = MediaType.parse("application/json");
        RequestBody body = RequestBody.create(mediaType, "");
        // Build the request
        Request request = new Request.Builder()
                .url(ACCESS_TOKEN_URI + "?client_id=" + apiKey + "&client_secret=" + secretKey + "&grant_type=client_credentials")
                .method("POST", body)
                .addHeader("Content-Type", "application/json")
                .build();
        // Send the request with the client; try-with-resources closes the response afterwards
        try (Response response = client.newCall(request).execute()) {
            // response.body().string() can be called only once; a second call throws because
            // the stream is already closed, so the result is kept in a variable
            String responseMessage = response.body().string();
            JSONObject jsonObject = JSON.parseObject(responseMessage);
            accessToken = (String) jsonObject.get("access_token");
            log.debug("accessToken obtained");
            return true;
        } catch (IOException e) {
            log.error("failed to obtain accessToken", e);
        }
        return false;
    }

    public ResponseBody getAnswerStream(String question) {
        getAccessToken();
        OkHttpClient client = new OkHttpClient();
        HashMap<String, String> user = new HashMap<>();
        user.put("role", "user");
        user.put("content", question);
        messages.add(user);
        String requestJson = constructRequestJson(1, 0.95, 0.8, 1.0, true, messages);
        RequestBody body = RequestBody.create(MediaType.parse("application/json"), requestJson);
        Request request = new Request.Builder()
                .url(CHAT_URI + "?access_token=" + accessToken)
                .method("POST", body)
                .addHeader("Content-Type", "application/json")
                .build();
        // Send the request synchronously; the body is read later as a stream
        try {
            Response response = client.newCall(request).execute();
            // Check whether the response succeeded
            if (response.isSuccessful()) {
                // Return the response body; the caller reads it as a stream and must close it
                return response.body();
            }
            // Release the resources of an unsuccessful response
            response.close();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return null;
    }
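
    /**
     * Build the JSON body of the chat request.
     * @param userId value sent as the user_id field
     * @param temperature value sent as the temperature field
     * @param topP value sent as the top_p field
     * @param penaltyScore value sent as the penalty_score field
     * @param stream whether to ask for a streaming response
     * @param messages conversation history
     * @return the request body as a JSON string
     */
    public String constructRequestJson(Integer userId,
                                       Double temperature,
                                       Double topP,
                                       Double penaltyScore,
                                       boolean stream,
                                       List<Map<String, String>> messages) {
        Map<String, Object> request = new HashMap<>();
        request.put("user_id", userId.toString());
        request.put("temperature", temperature);
        request.put("top_p", topP);
        request.put("penalty_score", penaltyScore);
        request.put("stream", stream);
        request.put("messages", messages);
        // Serialize once and reuse the result for logging and for the return value
        String json = JSON.toJSONString(request);
        log.debug("request json: {}", json);
        return json;
    }
}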
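Each streamed chunk usually carries several characters at once, so showing chunks as they arrive does not quite give the one-character-at-a-time feel mentioned at the start. One way to get closer is to split every chunk into single characters before displaying or forwarding it. The sketch below shows the idea in Java to stay consistent with the back-end code; `TypewriterSketch` and `splitIntoCharacters` are hypothetical names, and the same logic carries over to the front end.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch: split a chunk of text into single visible characters.
class TypewriterSketch {

    // Splits by Unicode code point, so characters stored as surrogate pairs stay whole.
    static List<String> splitIntoCharacters(String text) {
        List<String> pieces = new ArrayList<>();
        int index = 0;
        while (index < text.length()) {
            int codePoint = text.codePointAt(index);
            pieces.add(new String(Character.toChars(codePoint)));
            index += Character.charCount(codePoint);
        }
        return pieces;
    }

    public static void main(String[] args) {
        // "\u4f60\u597d" is the escaped form of two Chinese characters
        for (String piece : splitIntoCharacters("Hi \u4f60\u597d")) {
            System.out.println(piece);
        }
    }
}
```

Each piece would then be appended to the displayed text with a short delay between pieces; the delay itself is a matter of taste and is left out here.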
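The result looks like this (the front end should perhaps split the string so that the text feels like it is appearing one character at a time):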